Skip to content

Commit a87b03d

Browse files
Merge pull request #42 from leonvanbokhorst/Refactor-belief-update-in-DQNFEPAgent
Refactor belief update in DQNFEPAgent
2 parents 6211dbe + 5943a9d commit a87b03d

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

src/active_inference_forager/agents/dqn_fep_agent.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -154,10 +154,19 @@ def learn(
154154
self.decay_exploration()
155155

156156
def update_belief(self, observation: np.ndarray) -> None:
157+
# Validate that observation is numeric
158+
if not np.issubdtype(observation.dtype, np.number):
159+
raise ValueError("Observation must be a numeric array.")
160+
157161
self._update_belief_recursive(self.root_belief, observation)
158162
self._regularize_beliefs()
159163

160164
def _update_belief_recursive(self, node: BeliefNode, observation: np.ndarray):
165+
# Ensure observation is a numpy array with the same dtype as node.mean
166+
observation = np.asarray(observation)
167+
if observation.dtype != node.mean.dtype:
168+
observation = observation.astype(node.mean.dtype)
169+
161170
prediction_error = observation - node.mean
162171
node.precision += (
163172
np.outer(prediction_error, prediction_error) * self.learning_rate
@@ -297,6 +306,9 @@ def process_user_input(self, user_input: str) -> np.ndarray:
297306
words
298307
) # Politeness ratio
299308

309+
# Ensure features are of type float
310+
features = features.astype(float)
311+
300312
# Debug print statements
301313
print(f"Debug: Input string: '{user_input}'")
302314
print(f"Debug: Word count: {len(words)}")

src/active_inference_forager/main.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,13 @@ def simulate_conversation(
149149

150150
next_state, reward, done = env.step(action)
151151

152-
agent.update_belief(user_input)
152+
# Process user input into numerical features
153+
# create a variable with a np.array with ten random values between 0.0 and 1.0
154+
placeholder_user_input = np.random.rand(10)
155+
156+
processed_input = agent.process_user_input(placeholder_user_input) # user_input
157+
agent.update_belief(processed_input)
158+
153159
state = next_state
154160
turn += 1
155161

0 commit comments

Comments
 (0)