
Commit 2863406

Merge pull request #44 from leonvanbokhorst/Refactoring-generic-agent,-testing,-NER-implementation
Refactor generic agent and NER implementation
2 parents 9f7d565 + dc23734 commit 2863406

6 files changed: +217 −136 lines


requirements.txt

Lines changed: 4 additions & 1 deletion
@@ -4,7 +4,7 @@ fonttools==4.54.1
 iniconfig==2.0.0
 kiwisolver==1.4.7
 matplotlib==3.9.2
-numpy==2.1.2
+numpy==2.0.2
 packaging==24.1
 pillow==10.4.0
 pluggy==1.5.0
@@ -17,3 +17,6 @@ six==1.16.0
 pydantic
 torch
 torchvision
+textblob
+spacy==3.8.2
+https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl
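Note: the wheel URL pins en_core_web_sm 3.8.0 so pip installs the spaCy model directly, with no separate "python -m spacy download" step. A minimal sanity-check sketch after pip install -r requirements.txt (the sample sentences are illustrative, not part of the commit; TextBlob may additionally need "python -m textblob.download_corpora" for features beyond sentiment):

import spacy
from textblob import TextBlob

# The pinned wheel provides the small English pipeline the agents load.
nlp = spacy.load("en_core_web_sm")
doc = nlp("Immanuel Kant was born in Prussia in 1724.")
print([(ent.text, ent.label_) for ent in doc.ents])   # whatever entities the model finds
print(TextBlob("This is wonderful!").sentiment)        # Sentiment(polarity=..., subjectivity=...)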

src/active_inference_forager/agents/generic_agent.py

Lines changed: 67 additions & 15 deletions
@@ -6,6 +6,9 @@
 from pydantic import Field, ConfigDict
 from collections import deque
 import random
+import re
+from textblob import TextBlob
+import spacy

 from active_inference_forager.agents.base_agent import BaseAgent
 from active_inference_forager.agents.belief_node import BeliefNode
@@ -80,9 +83,13 @@ class GenericAgent(BaseAgent):
     episode_lengths: List[int] = Field(default_factory=list)
     total_steps: int = Field(default=0)

+    # NLP model
+    nlp: spacy.language.Language = Field(default_factory=lambda: spacy.load("en_core_web_sm"))
+
     model_config = ConfigDict(arbitrary_types_allowed=True)

-    def __init__(self, state_dim: int, action_dim: int, **kwargs):
+    def __init__(self, action_dim: int, **kwargs):
+        state_dim = 17  # Updated to match the environment's state dimension
         super().__init__(state_dim=state_dim, action_dim=action_dim, **kwargs)

         self.root_belief = BeliefNode(
@@ -94,12 +101,19 @@ def __init__(self, state_dim: int, action_dim: int, **kwargs):
         self.target_network.load_state_dict(self.q_network.state_dict())
         self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.learning_rate)

-        self.initialize_belief_and_action_space()
+        self.action_space = [
+            "ask_question",
+            "provide_information",
+            "clarify",
+            "suggest_action",
+            "express_empathy",
+            "end_conversation",
+        ]
         self.exploration_rate = self.epsilon_start

     def take_action(self, state: np.ndarray) -> str:
         if np.random.rand() < self.exploration_rate:
-            return np.random.choice(self.action_space)
+            return random.choice(self.action_space)
         state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
         q_values = self.q_network(state_tensor)
         return self.action_space[q_values.argmax().item()]
@@ -274,19 +288,57 @@ def _build_belief_hierarchy(self, node: BeliefNode, level: int):
         self._build_belief_hierarchy(child, level + 1)

     def process_user_input(self, user_input: str) -> np.ndarray:
-        features = np.zeros(5)
+        features = np.zeros(17)  # Updated to match the environment's state dimension

+        # Basic text statistics
         words = user_input.split()
-        features[0] = len(words)
-        features[1] = user_input.count("?") / len(words)
-        features[2] = user_input.count("!") / len(words)
-        features[3] = len(user_input) / 100
-        features[4] = sum(
-            1
-            for word in words
-            if word.lower() in ["please", "thank", "thanks", "appreciate"]
-        ) / len(words)
-
-        features = features.astype(float)
+        features[0] = len(words)  # Word count
+        features[1] = sum(len(word) for word in words) / max(len(words), 1)  # Average word length
+        features[2] = user_input.count("?") / max(len(words), 1)  # Question mark frequency
+        features[3] = user_input.count("!") / max(len(words), 1)  # Exclamation mark frequency
+
+        # Sentiment analysis
+        blob = TextBlob(user_input)
+        features[4] = blob.sentiment.polarity  # Sentiment polarity (-1 to 1)
+        features[5] = blob.sentiment.subjectivity  # Subjectivity (0 to 1)
+
+        # Keyword detection
+        keywords = ["help", "explain", "understand", "confused", "clarify"]
+        features[6] = sum(word.lower() in keywords for word in words) / max(len(words), 1)
+
+        # Complexity indicators
+        features[7] = len(set(words)) / max(len(words), 1)  # Lexical diversity
+        features[8] = sum(len(word) > 6 for word in words) / max(len(words), 1)  # Proportion of long words
+
+        # Politeness indicator
+        polite_words = ["please", "thank", "thanks", "appreciate", "kindly"]
+        features[9] = sum(word.lower() in polite_words for word in words) / max(len(words), 1)
+
+        # spaCy processing
+        doc = self.nlp(user_input)
+
+        # Named Entity Recognition
+        features[10] = len(doc.ents) / max(len(words), 1)  # Named entity density
+
+        # Part-of-speech tagging
+        pos_counts = {pos: 0 for pos in ['NOUN', 'VERB', 'ADJ', 'ADV']}
+        for token in doc:
+            if token.pos_ in pos_counts:
+                pos_counts[token.pos_] += 1
+        features[11] = pos_counts['NOUN'] / max(len(words), 1)  # Noun density
+        features[12] = pos_counts['VERB'] / max(len(words), 1)  # Verb density
+
+        # Dependency parsing
+        features[13] = len([token for token in doc if token.dep_ == 'ROOT']) / max(len(words), 1)  # Main clause density
+
+        # Sentence complexity (using dependency parse tree depth)
+        def tree_depth(token):
+            return 1 + max((tree_depth(child) for child in token.children), default=0)
+
+        features[14] = sum(tree_depth(sent.root) for sent in doc.sents) / max(len(list(doc.sents)), 1)  # Average parse tree depth
+
+        # Additional features to match the environment's state dimension
+        features[15] = len([token for token in doc if token.is_stop]) / max(len(words), 1)  # Stop word density
+        features[16] = len([token for token in doc if token.is_punct]) / max(len(words), 1)  # Punctuation density

         return features
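Note: process_user_input now returns a 17-dimensional feature vector instead of 5, combining surface statistics, TextBlob sentiment, and spaCy NER, POS, dependency, stop-word, and punctuation features. A small self-contained sketch of the two new spaCy-based features (named-entity density at index 10 and average parse-tree depth at index 14); it mirrors the diff above but is not the project's code path, and the sample text is made up:

import spacy

nlp = spacy.load("en_core_web_sm")

def tree_depth(token):
    # Depth of the dependency subtree rooted at this token.
    return 1 + max((tree_depth(child) for child in token.children), default=0)

text = "Could you please explain what Plato meant by the Cave?"
doc = nlp(text)
words = text.split()

entity_density = len(doc.ents) / max(len(words), 1)  # feature index 10
avg_parse_depth = sum(tree_depth(sent.root) for sent in doc.sents) / max(len(list(doc.sents)), 1)  # feature index 14
print(entity_density, avg_parse_depth)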

src/active_inference_forager/agents/philosophy_tutor_agent.py

Lines changed: 42 additions & 16 deletions
@@ -2,13 +2,15 @@
 from typing import Dict, List
 from pydantic import Field
 from active_inference_forager.agents.generic_agent import GenericAgent
+import spacy


 class PhilosophyTutorAgent(GenericAgent):
     knowledge_base: Dict[str, Dict] = Field(default_factory=dict)
+    nlp: spacy.language.Language = Field(default_factory=lambda: spacy.load("en_core_web_sm"))

-    def __init__(self, state_dim: int, action_dim: int, **kwargs):
-        super().__init__(state_dim=state_dim, action_dim=action_dim, **kwargs)
+    def __init__(self, action_dim: int, **kwargs):
+        super().__init__(action_dim=action_dim, **kwargs)
         self.action_space = [
             "explain_concept",
             "ask_question",
@@ -67,25 +69,50 @@ def generate_response(self, action: str, state: np.ndarray) -> str:
         return "I'm not sure how to respond to that."

     def explain_philosophical_concept(self, state: np.ndarray) -> str:
-        # TODO: Implement logic to choose a concept based on the state
-        concept = "epistemology"  # Placeholder
+        # Use the state vector to choose a concept
+        concepts = list(self.knowledge_base['concepts'].keys())
+        concept_index = int(state[0] * len(concepts)) % len(concepts)
+        concept = concepts[concept_index]
         return f"Let me explain {concept}: {self.knowledge_base['concepts'][concept]}"

     def ask_socratic_question(self, state: np.ndarray) -> str:
-        # TODO: Implement logic to generate a relevant question based on the state
-        return "What do you think it means to truly know something?"
+        # Use the state vector to generate a relevant question
+        questions = [
+            "What do you think it means to truly know something?",
+            "How can we determine what is morally right or wrong?",
+            "What is the nature of reality, in your opinion?",
+            "How do you think we can achieve a just society?",
+        ]
+        question_index = int(state[1] * len(questions)) % len(questions)
+        return questions[question_index]

     def introduce_related_idea(self, state: np.ndarray) -> str:
-        # TODO: Implement logic to choose a related idea based on the state
-        return "Have you considered how this relates to the concept of free will?"
+        # Use the state vector to choose a related idea
+        ideas = [
+            "free will",
+            "consciousness",
+            "personal identity",
+            "the meaning of life",
+        ]
+        idea_index = int(state[2] * len(ideas)) % len(ideas)
+        return f"Have you considered how this relates to the concept of {ideas[idea_index]}?"

     def provide_example(self, state: np.ndarray) -> str:
-        # TODO: Implement logic to choose a relevant example based on the state
-        return "For instance, consider how we use logic in everyday decision-making..."
+        # Use the state vector to choose a relevant example
+        examples = [
+            "Consider how we use logic in everyday decision-making...",
+            "Think about how ethical considerations shape our laws and social norms...",
+            "Reflect on how our understanding of reality influences our actions...",
+            "Examine how our beliefs about knowledge affect our learning processes...",
+        ]
+        example_index = int(state[3] * len(examples)) % len(examples)
+        return examples[example_index]

     def suggest_thought_experiment(self, state: np.ndarray) -> str:
-        # TODO: Implement logic to choose a thought experiment based on the state
-        experiment = "The Cave"  # Placeholder
+        # Use the state vector to choose a thought experiment
+        experiments = list(self.knowledge_base['thought_experiments'].keys())
+        experiment_index = int(state[4] * len(experiments)) % len(experiments)
+        experiment = experiments[experiment_index]
         return f"Let's explore {experiment}: {self.knowledge_base['thought_experiments'][experiment]}"

     def acknowledge_limitation(self, state: np.ndarray) -> str:
@@ -95,9 +122,8 @@ def update_belief(self, state: np.ndarray):
         # We're now working directly with the state vector
         super().update_belief(state)

-    def process_user_input(self, state: np.ndarray) -> np.ndarray:
-        # This method now processes the state vector instead of a string
-        # We'll return the state as is, since it's already a numpy array
-        return state
+    def process_user_input(self, user_input: str) -> np.ndarray:
+        # Use the GenericAgent's process_user_input method
+        return super().process_user_input(user_input)

     # ... rest of the class implementation remains the same ...
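Note: each tutor method above maps one state feature onto a fixed list of options with int(state[i] * len(options)) % len(options). A standalone sketch of that selection pattern (the feature values and option list are made-up examples, not project data):

options = ["free will", "consciousness", "personal identity", "the meaning of life"]

def pick(feature_value: float, options: list) -> str:
    # A feature in [0, 1) selects an option proportionally; the modulo keeps
    # the index valid even if the feature falls outside that range.
    index = int(feature_value * len(options)) % len(options)
    return options[index]

print(pick(0.30, options))  # 0.30 * 4 = 1.2 -> index 1 -> "consciousness"
print(pick(0.95, options))  # 0.95 * 4 = 3.8 -> index 3 -> "the meaning of life"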
