Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: save/load test for dotnet agents #5284

Merged
merged 14 commits into from
Feb 6, 2025
6 changes: 4 additions & 2 deletions dotnet/src/Microsoft.AutoGen/Contracts/AgentProxy.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentProxy.cs

using System.Text.Json;

namespace Microsoft.AutoGen.Contracts;

/// <summary>
Expand Down Expand Up @@ -55,7 +57,7 @@ private T ExecuteAndUnwrap<T>(Func<IAgentRuntime, ValueTask<T>> delegate_)
/// </summary>
/// <param name="state">A dictionary representing the state of the agent. Must be JSON serializable.</param>
/// <returns>A task representing the asynchronous operation.</returns>
public ValueTask LoadStateAsync(IDictionary<string, object> state)
public ValueTask LoadStateAsync(IDictionary<string, JsonElement> state)
{
return this.runtime.LoadAgentStateAsync(this.Id, state);
}
Expand All @@ -64,7 +66,7 @@ public ValueTask LoadStateAsync(IDictionary<string, object> state)
/// Saves the state of the agent. The result must be JSON serializable.
/// </summary>
/// <returns>A task representing the asynchronous operation, returning a dictionary containing the saved state.</returns>
public ValueTask<IDictionary<string, object>> SaveStateAsync()
public ValueTask<IDictionary<string, JsonElement>> SaveStateAsync()
{
return this.runtime.SaveAgentStateAsync(this.Id);
}
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IAgentRuntime.cs

using StateDict = System.Collections.Generic.IDictionary<string, object>;
using StateDict = System.Collections.Generic.IDictionary<string, System.Text.Json.JsonElement>;

namespace Microsoft.AutoGen.Contracts;

Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ISaveState.cs

using StateDict = System.Collections.Generic.IDictionary<string, object>;
using StateDict = System.Collections.Generic.IDictionary<string, System.Text.Json.JsonElement>;

namespace Microsoft.AutoGen.Contracts;

Expand Down
39 changes: 22 additions & 17 deletions dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// GrpcAgentRuntime.cs

using System.Collections.Concurrent;
using System.Text.Json;
using Grpc.Core;
using Microsoft.AutoGen.Contracts;
using Microsoft.AutoGen.Protobuf;
Expand Down Expand Up @@ -319,13 +320,13 @@
public ValueTask<Contracts.AgentId> GetAgentAsync(string agent, string key = "default", bool lazy = true)
=> this.GetAgentAsync(new Contracts.AgentId(agent, key), lazy);

public async ValueTask<IDictionary<string, object>> SaveAgentStateAsync(Contracts.AgentId agentId)
public async ValueTask<IDictionary<string, JsonElement>> SaveAgentStateAsync(Contracts.AgentId agentId)
{
IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId);
return await agent.SaveStateAsync();
}

public async ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary<string, object> state)
public async ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary<string, JsonElement> state)
{
IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId);
await agent.LoadStateAsync(state);
Expand Down Expand Up @@ -375,37 +376,41 @@
return ValueTask.FromResult(new AgentProxy(agentId, this));
}

public async ValueTask<IDictionary<string, object>> SaveStateAsync()
{
Dictionary<string, object> state = new();
foreach (var agent in this._agentsContainer.LiveAgents)
{
state[agent.Id.ToString()] = await agent.SaveStateAsync();
}

return state;
}

public async ValueTask LoadStateAsync(IDictionary<string, object> state)
public async ValueTask LoadStateAsync(IDictionary<string, JsonElement> state)
{
HashSet<AgentType> registeredTypes = this._agentsContainer.RegisteredAgentTypes;

foreach (var agentIdStr in state.Keys)
{
Contracts.AgentId agentId = Contracts.AgentId.FromStr(agentIdStr);
if (state[agentIdStr] is not IDictionary<string, object> agentStateDict)

if (state[agentIdStr].ValueKind != JsonValueKind.Object)
{
throw new Exception($"Agent state for {agentId} is not a {typeof(IDictionary<string, object>)}: {state[agentIdStr].GetType()}");
throw new Exception($"Agent state for {agentId} is not a valid JSON object.");

Check warning on line 389 in dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs

View check run for this annotation

Codecov / codecov/patch

dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs#L389

Added line #L389 was not covered by tests
}

var agentState = JsonSerializer.Deserialize<IDictionary<string, JsonElement>>(state[agentIdStr].GetRawText())
?? throw new Exception($"Failed to deserialize state for {agentId}.");

Check warning on line 393 in dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs

View check run for this annotation

Codecov / codecov/patch

dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs#L393

Added line #L393 was not covered by tests

if (registeredTypes.Contains(agentId.Type))
{
IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId);
await agent.LoadStateAsync(agentStateDict);
await agent.LoadStateAsync(agentState);

Check warning on line 398 in dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs

View check run for this annotation

Codecov / codecov/patch

dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs#L398

Added line #L398 was not covered by tests
}
}
}

public async ValueTask<IDictionary<string, JsonElement>> SaveStateAsync()
{
Dictionary<string, JsonElement> state = new();

Check warning on line 405 in dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs

View check run for this annotation

Codecov / codecov/patch

dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs#L405

Added line #L405 was not covered by tests
foreach (var agent in this._agentsContainer.LiveAgents)
{
var agentState = await agent.SaveStateAsync();
state[agent.Id.ToString()] = JsonSerializer.SerializeToElement(agentState);
}
return state;
}

Check warning on line 412 in dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs

View check run for this annotation

Codecov / codecov/patch

dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs#L408-L412

Added lines #L408 - L412 were not covered by tests

public async ValueTask OnMessageAsync(Message message, CancellationToken cancellation = default)
{
switch (message.MessageCase)
Expand Down
7 changes: 4 additions & 3 deletions dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System.Diagnostics;
using System.Reflection;
using System.Text.Json;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.Logging;

Expand Down Expand Up @@ -92,11 +93,11 @@
return null;
}

public virtual ValueTask<IDictionary<string, object>> SaveStateAsync()
public virtual ValueTask<IDictionary<string, JsonElement>> SaveStateAsync()
{
return ValueTask.FromResult<IDictionary<string, object>>(new Dictionary<string, object>());
return ValueTask.FromResult<IDictionary<string, JsonElement>>(new Dictionary<string, JsonElement>());

Check warning on line 98 in dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs

View check run for this annotation

Codecov / codecov/patch

dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs#L98

Added line #L98 was not covered by tests
}
public virtual ValueTask LoadStateAsync(IDictionary<string, object> state)
public virtual ValueTask LoadStateAsync(IDictionary<string, JsonElement> state)
{
return ValueTask.CompletedTask;
}
Expand Down
26 changes: 16 additions & 10 deletions dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Collections.Concurrent;
using System.Diagnostics;
using System.Text.Json;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.Hosting;

Expand All @@ -12,7 +13,7 @@
{
public bool DeliverToSelf { get; set; } //= false;

Dictionary<AgentId, IHostableAgent> agentInstances = new();
internal Dictionary<AgentId, IHostableAgent> agentInstances = new();
Dictionary<string, ISubscriptionDefinition> subscriptions = new();
Dictionary<AgentType, Func<AgentId, IAgentRuntime, ValueTask<IHostableAgent>>> agentFactories = new();

Expand Down Expand Up @@ -152,13 +153,13 @@
return agent.Metadata;
}

public async ValueTask LoadAgentStateAsync(AgentId agentId, IDictionary<string, object> state)
public async ValueTask LoadAgentStateAsync(AgentId agentId, IDictionary<string, JsonElement> state)
{
IHostableAgent agent = await this.EnsureAgentAsync(agentId);
await agent.LoadStateAsync(state);
}

public async ValueTask<IDictionary<string, object>> SaveAgentStateAsync(AgentId agentId)
public async ValueTask<IDictionary<string, JsonElement>> SaveAgentStateAsync(AgentId agentId)
{
IHostableAgent agent = await this.EnsureAgentAsync(agentId);
return await agent.SaveStateAsync();
Expand Down Expand Up @@ -187,16 +188,21 @@
return ValueTask.CompletedTask;
}

public async ValueTask LoadStateAsync(IDictionary<string, object> state)
public async ValueTask LoadStateAsync(IDictionary<string, JsonElement> state)
{
foreach (var agentIdStr in state.Keys)
{
AgentId agentId = AgentId.FromStr(agentIdStr);
if (state[agentIdStr] is not IDictionary<string, object> agentState)

if (state[agentIdStr].ValueKind != JsonValueKind.Object)
{
throw new Exception($"Agent state for {agentId} is not a {typeof(IDictionary<string, object>)}: {state[agentIdStr].GetType()}");
throw new Exception($"Agent state for {agentId} is not a valid JSON object.");

Check warning on line 199 in dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs

View check run for this annotation

Codecov / codecov/patch

dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs#L199

Added line #L199 was not covered by tests
}

// Deserialize before using
var agentState = JsonSerializer.Deserialize<IDictionary<string, JsonElement>>(state[agentIdStr].GetRawText())
?? throw new Exception($"Failed to deserialize state for {agentId}.");

if (this.agentFactories.ContainsKey(agentId.Type))
{
IHostableAgent agent = await this.EnsureAgentAsync(agentId);
Expand All @@ -205,14 +211,14 @@
}
}

public async ValueTask<IDictionary<string, object>> SaveStateAsync()
public async ValueTask<IDictionary<string, JsonElement>> SaveStateAsync()
{
Dictionary<string, object> state = new();
Dictionary<string, JsonElement> state = new();
foreach (var agentId in this.agentInstances.Keys)
{
state[agentId.ToString()] = await this.agentInstances[agentId].SaveStateAsync();
var agentState = await this.agentInstances[agentId].SaveStateAsync();
state[agentId.ToString()] = JsonSerializer.SerializeToElement(agentState);
}

return state;
}

Expand Down
6 changes: 6 additions & 0 deletions dotnet/src/Microsoft.AutoGen/Core/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AssemblyInfo.cs

using System.Runtime.CompilerServices;

[assembly: InternalsVisibleTo("Microsoft.AutoGen.Core.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f1d038d0b85ae392ad72011df91e9343b0b5df1bb8080aa21b9424362d696919e0e9ac3a8bca24e283e10f7a569c6f443e1d4e3ebc84377c87ca5caa562e80f9932bf5ea91b7862b538e13b8ba91c7565cf0e8dfeccfea9c805ae3bda044170ecc7fc6f147aeeac422dd96aeb9eb1f5a5882aa650efe2958f2f8107d2038f2ab")]
83 changes: 0 additions & 83 deletions dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs

This file was deleted.

23 changes: 1 addition & 22 deletions dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) =>
return ValueTask.FromResult(agent);
});

// Ensure the agent is actually created
// Ensure the agent id is registered
AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false);

// Validate agent ID
Expand Down Expand Up @@ -146,25 +146,4 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) =>

Assert.True(agent.ReceivedItems.Count == 1);
}

[Fact]
public async Task AgentShouldSaveStateCorrectlyTest()
{
var runtime = new InProcessRuntime();
await runtime.StartAsync();

Logger<BaseAgent> logger = new(new LoggerFactory());
TestAgent agent = new TestAgent(new AgentId("TestType", "TestKey"), runtime, logger);

var state = await agent.SaveStateAsync();

// Ensure state is a dictionary
state.Should().NotBeNull();
state.Should().BeOfType<Dictionary<string, object>>();
state.Should().BeEmpty("Default SaveStateAsync should return an empty dictionary.");

// Add a sample value and verify it updates correctly
state["testKey"] = "testValue";
state.Should().ContainKey("testKey").WhoseValue.Should().Be("testValue");
}
}
Loading
Loading