diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/AGUIEndpointRouteBuilderExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/AGUIEndpointRouteBuilderExtensions.cs index e20d1ab448..8d6df9aa93 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/AGUIEndpointRouteBuilderExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/AGUIEndpointRouteBuilderExtensions.cs @@ -3,7 +3,9 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Tasks; using Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.Shared; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; @@ -13,6 +15,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.Agents.AI.Hosting; namespace Microsoft.Agents.AI.Hosting.AGUI.AspNetCore; @@ -42,6 +45,7 @@ public static IEndpointConventionBuilder MapAGUI( var jsonOptions = context.RequestServices.GetRequiredService>(); var jsonSerializerOptions = jsonOptions.Value.SerializerOptions; + var sessionStore = context.RequestServices.GetRequiredService(); var messages = input.Messages.AsChatMessages(jsonSerializerOptions); var clientTools = input.Tools?.AsAITools().ToList(); @@ -63,11 +67,17 @@ public static IEndpointConventionBuilder MapAGUI( } }; + AgentSession? session = await GetOrCreateSessionAsync(aiAgent, input.ThreadId, sessionStore, cancellationToken).ConfigureAwait(false); + // Run the agent and convert to AG-UI events - var events = aiAgent.RunStreamingAsync( + var events = RunStreamingWithSessionPersistenceAsync( + aiAgent, messages, - options: runOptions, - cancellationToken: cancellationToken) + runOptions, + session, + input.ThreadId, + sessionStore, + cancellationToken) .AsChatResponseUpdatesAsync() .FilterServerToolsFromMixedToolInvocationsAsync(clientTools, cancellationToken) .AsAGUIEventStreamAsync( @@ -80,4 +90,59 @@ public static IEndpointConventionBuilder MapAGUI( return new AGUIServerSentEventsResult(events, sseLogger); }); } + + private static async ValueTask GetOrCreateSessionAsync( + AIAgent aiAgent, + string? threadId, + AgentSessionStore sessionStore, + CancellationToken cancellationToken) + { + if (string.IsNullOrWhiteSpace(threadId)) + { + return null; + } + + return await sessionStore.GetSessionAsync(aiAgent, threadId, cancellationToken).ConfigureAwait(false); + } + + private static async ValueTask PersistSessionAsync( + AIAgent aiAgent, + string? threadId, + AgentSession? session, + AgentSessionStore sessionStore, + CancellationToken cancellationToken) + { + if (session is null || string.IsNullOrWhiteSpace(threadId)) + { + return; + } + + await sessionStore.SaveSessionAsync(aiAgent, threadId, session, cancellationToken).ConfigureAwait(false); + } + + private static async IAsyncEnumerable RunStreamingWithSessionPersistenceAsync( + AIAgent aiAgent, + IEnumerable messages, + AgentRunOptions runOptions, + AgentSession? session, + string? threadId, + AgentSessionStore sessionStore, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + try + { + await foreach (AgentResponseUpdate update in aiAgent.RunStreamingAsync( + messages, + session, + runOptions, + CancellationToken.None).ConfigureAwait(false)) + { + yield return update; + } + } + finally + { + await PersistSessionAsync(aiAgent, threadId, session, sessionStore, cancellationToken).ConfigureAwait(false); + } + } } diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/AGUIInMemorySessionStore.cs b/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/AGUIInMemorySessionStore.cs new file mode 100644 index 0000000000..431e80a3dd --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/AGUIInMemorySessionStore.cs @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Agents.AI.Hosting; +using Microsoft.Extensions.Caching.Memory; + +namespace Microsoft.Agents.AI.Hosting.AGUI.AspNetCore; + +/// +/// Provides an in-memory implementation for AG-UI hosted agents. +/// +/// +/// This store is intended for single-instance development and testing scenarios. Applications that need +/// durable or distributed session persistence can replace the registered +/// service with a custom implementation. +/// +public sealed class AGUIInMemorySessionStore : AgentSessionStore, IDisposable +{ + private readonly MemoryCache _cache; + private readonly MemoryCacheEntryOptions _entryOptions; + private readonly ConcurrentDictionary> _sessionInitializationTasks = new(); + + /// + /// Initializes a new instance of the class with default options. + /// + public AGUIInMemorySessionStore() + : this(options: null) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The cache options to apply. If , default options are used. + public AGUIInMemorySessionStore(AGUIInMemorySessionStoreOptions? options) + { + AGUIInMemorySessionStoreOptions resolvedOptions = options ?? new(); + this._cache = new MemoryCache(resolvedOptions.ToMemoryCacheOptions()); + this._entryOptions = resolvedOptions.ToMemoryCacheEntryOptions(); + } + + /// + public override ValueTask GetSessionAsync(AIAgent agent, string conversationId, CancellationToken cancellationToken = default) + => this.GetOrCreateSessionAsync(agent, conversationId, cancellationToken); + + /// + /// Gets the session for the specified conversation or creates a new one when none exists. + /// + /// The agent that owns the session. + /// The conversation or thread identifier. + /// The cancellation token. + /// The existing or newly created session. + public ValueTask GetOrCreateSessionAsync(AIAgent agent, string threadId, CancellationToken cancellationToken = default) + { + SessionCacheKey key = GetKey(agent, threadId); + if (this._cache.TryGetValue(key, out AgentSession? session) && session is not null) + { + return new(session); + } + + return this.GetOrCreateSessionCoreAsync(agent, key, cancellationToken); + } + + /// + public override ValueTask SaveSessionAsync(AIAgent agent, string conversationId, AgentSession session, CancellationToken cancellationToken = default) + { + this._cache.Set(GetKey(agent, conversationId), session, this._entryOptions); + return ValueTask.CompletedTask; + } + + private async ValueTask GetOrCreateSessionCoreAsync(AIAgent agent, SessionCacheKey key, CancellationToken cancellationToken) + { + TaskCompletionSource initializationTask = new(TaskCreationOptions.RunContinuationsAsynchronously); + TaskCompletionSource sharedInitializationTask = this._sessionInitializationTasks.GetOrAdd(key, initializationTask); + + if (ReferenceEquals(sharedInitializationTask, initializationTask)) + { + if (this._cache.TryGetValue(key, out AgentSession? existingSession) && existingSession is not null) + { + initializationTask.TrySetResult(existingSession); + this._sessionInitializationTasks.TryRemove(new KeyValuePair>(key, initializationTask)); + } + else + { + _ = this.InitializeSessionAsync(agent, key, initializationTask); + } + } + + return await sharedInitializationTask.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + } + + private async Task CreateAndStoreSessionAsync(AIAgent agent, SessionCacheKey key) + { + AgentSession session = await agent.CreateSessionAsync(CancellationToken.None).ConfigureAwait(false); + return this._cache.GetOrCreate(key, entry => + { + entry.SetOptions(this._entryOptions); + return session; + })!; + } + + private async Task InitializeSessionAsync(AIAgent agent, SessionCacheKey key, TaskCompletionSource initializationTask) + { + try + { + AgentSession session = await this.CreateAndStoreSessionAsync(agent, key).ConfigureAwait(false); + initializationTask.TrySetResult(session); + } + catch (Exception ex) + { + initializationTask.TrySetException(ex); + } + finally + { + this._sessionInitializationTasks.TryRemove(new KeyValuePair>(key, initializationTask)); + } + } + + /// + /// Releases the underlying memory cache. + /// + public void Dispose() + { + this._cache.Dispose(); + } + + private static SessionCacheKey GetKey(AIAgent agent, string threadId) => new(agent.Id, threadId); + + private sealed record SessionCacheKey(string AgentId, string ThreadId); +} + +/// +/// Configures the default registration used by AddAGUI. +/// +public sealed class AGUIInMemorySessionStoreOptions +{ + /// + /// Gets or sets the maximum number of sessions to retain in memory. + /// + public long? SizeLimit { get; set; } = 1000; + + /// + /// Gets or sets the absolute expiration applied to cached sessions. + /// + public TimeSpan? AbsoluteExpirationRelativeToNow { get; set; } + + /// + /// Gets or sets the sliding expiration applied to cached sessions. + /// + public TimeSpan? SlidingExpiration { get; set; } = TimeSpan.FromHours(1); + + internal MemoryCacheOptions ToMemoryCacheOptions() => new() + { + SizeLimit = this.SizeLimit + }; + + internal MemoryCacheEntryOptions ToMemoryCacheEntryOptions() => new() + { + AbsoluteExpirationRelativeToNow = this.AbsoluteExpirationRelativeToNow, + SlidingExpiration = this.SlidingExpiration, + Size = 1 + }; +} \ No newline at end of file diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.csproj b/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.csproj index d6169ad805..1565977149 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.csproj +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.csproj @@ -19,6 +19,7 @@ + diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/ServiceCollectionExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/ServiceCollectionExtensions.cs index e159c0727e..f80183c0d6 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/ServiceCollectionExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/ServiceCollectionExtensions.cs @@ -2,8 +2,11 @@ using System; using Microsoft.Agents.AI; +using Microsoft.Agents.AI.Hosting; using Microsoft.Agents.AI.Hosting.AGUI.AspNetCore; using Microsoft.AspNetCore.Http.Json; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Options; namespace Microsoft.Extensions.DependencyInjection; @@ -22,6 +25,26 @@ public static IServiceCollection AddAGUI(this IServiceCollection services) ArgumentNullException.ThrowIfNull(services); services.Configure(options => options.SerializerOptions.TypeInfoResolverChain.Add(AGUIJsonSerializerOptions.Default.TypeInfoResolver!)); + services.AddOptions(); + services.TryAddSingleton(sp => + new AGUIInMemorySessionStore(sp.GetRequiredService>().Value)); + + return services; + } + + /// + /// Adds support for exposing instances via AG-UI and configures the default in-memory session store. + /// + /// The to configure. + /// Configures the default . + /// The for method chaining. + public static IServiceCollection AddAGUI(this IServiceCollection services, Action configureSessionStore) + { + ArgumentNullException.ThrowIfNull(services); + ArgumentNullException.ThrowIfNull(configureSessionStore); + + services.AddAGUI(); + services.Configure(configureSessionStore); return services; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs index d94e520420..728cd8037b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs @@ -163,6 +163,42 @@ public async Task MultiTurnConversationPreservesAllMessagesInSessionAsync() secondResponse.Messages[0].Text.Should().Be("Hello from fake agent!"); } + [Fact] + public async Task ChatClientConversationId_RestoresServerSessionAcrossRequestsAsync() + { + // Arrange + await this.SetupTestServerAsync(useStatefulMemoryAgent: true); + var chatClient = new AGUIChatClient(this._client!, "", null); + var chatOptions = new ChatOptions(); + + // Act + ChatResponse firstResponse = await chatClient.GetResponseAsync([new ChatMessage(ChatRole.User, "My name is Alice")], chatOptions); + chatOptions.ConversationId = firstResponse.ConversationId; + ChatResponse secondResponse = await chatClient.GetResponseAsync([new ChatMessage(ChatRole.User, "What is my name?")], chatOptions); + + // Assert + firstResponse.ConversationId.Should().NotBeNullOrEmpty(); + secondResponse.Text.Should().Contain("Alice"); + } + + [Fact] + public async Task AsAIAgentSession_RestoresServerSessionAcrossRunsAsync() + { + // Arrange + await this.SetupTestServerAsync(useStatefulMemoryAgent: true); + var chatClient = new AGUIChatClient(this._client!, "", null); + AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Stateful assistant", tools: []); + AgentSession session = await agent.CreateSessionAsync(); + + // Act + AgentResponse firstResponse = await agent.RunAsync("My name is Alice", session, new AgentRunOptions(), CancellationToken.None); + AgentResponse secondResponse = await agent.RunAsync("What is my name?", session, new AgentRunOptions(), CancellationToken.None); + + // Assert + firstResponse.Text.Should().Contain("Alice"); + secondResponse.Text.Should().Contain("Alice"); + } + [Fact] public async Task AgentSendsMultipleMessagesInOneTurnAsync() { @@ -231,14 +267,18 @@ public async Task UserSendsMultipleMessagesAtOnceAsync() response.Messages[0].Text.Should().Be("Hello from fake agent!"); } - private async Task SetupTestServerAsync(bool useMultiMessageAgent = false) + private async Task SetupTestServerAsync(bool useMultiMessageAgent = false, bool useStatefulMemoryAgent = false) { WebApplicationBuilder builder = WebApplication.CreateBuilder(); builder.WebHost.UseTestServer(); builder.Services.AddAGUI(); - if (useMultiMessageAgent) + if (useStatefulMemoryAgent) + { + builder.Services.AddSingleton(); + } + else if (useMultiMessageAgent) { builder.Services.AddSingleton(); } @@ -249,9 +289,11 @@ private async Task SetupTestServerAsync(bool useMultiMessageAgent = false) this._app = builder.Build(); - AIAgent agent = useMultiMessageAgent - ? this._app.Services.GetRequiredService() - : this._app.Services.GetRequiredService(); + AIAgent agent = useStatefulMemoryAgent + ? this._app.Services.GetRequiredService() + : useMultiMessageAgent + ? this._app.Services.GetRequiredService() + : this._app.Services.GetRequiredService(); this._app.MapAGUI("/agent", agent); @@ -441,3 +483,98 @@ public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag) public override object? GetService(Type serviceType, object? serviceKey = null) => null; } + +[SuppressMessage("Performance", "CA1812:Avoid uninstantiated internal classes", Justification = "Instantiated via dependency injection")] +internal sealed class FakeStatefulMemoryAgent : AIAgent +{ + private const string NameStateKey = "remembered-name"; + + protected override string? IdCore => "fake-stateful-memory-agent"; + + public override string? Description => "A fake agent that stores simple memory in the session"; + + protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => + new(new FakeAgentSession()); + + protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => + new(serializedState.Deserialize(jsonSerializerOptions)!); + + protected override ValueTask SerializeSessionCoreAsync(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + { + if (session is not FakeAgentSession typedSession) + { + throw new InvalidOperationException($"The provided session type '{session.GetType().Name}' is not compatible with this agent. Only sessions of type '{nameof(FakeAgentSession)}' can be serialized by this agent."); + } + + return new(JsonSerializer.SerializeToElement(typedSession, jsonSerializerOptions)); + } + + protected override async Task RunCoreAsync( + IEnumerable messages, + AgentSession? session = null, + AgentRunOptions? options = null, + CancellationToken cancellationToken = default) + { + List updates = []; + await foreach (AgentResponseUpdate update in this.RunStreamingAsync(messages, session, options, cancellationToken).ConfigureAwait(false)) + { + updates.Add(update); + } + + return updates.ToAgentResponse(); + } + + protected override async IAsyncEnumerable RunCoreStreamingAsync( + IEnumerable messages, + AgentSession? session = null, + AgentRunOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (session is not FakeAgentSession typedSession) + { + throw new InvalidOperationException($"The provided session type '{session?.GetType().Name ?? "null"}' is not compatible with this agent. Only sessions of type '{nameof(FakeAgentSession)}' can be used by this agent."); + } + + string lastUserInput = messages.Last(m => m.Role == ChatRole.User).Text ?? string.Empty; + string responseText; + + if (lastUserInput.StartsWith("My name is ", StringComparison.OrdinalIgnoreCase)) + { + string rememberedName = lastUserInput[11..].Trim(); + typedSession.StateBag.SetValue(NameStateKey, rememberedName); + responseText = $"Nice to meet you {rememberedName}."; + } + else if (lastUserInput.Equals("What is my name?", StringComparison.OrdinalIgnoreCase)) + { + string? rememberedName = typedSession.StateBag.GetValue(NameStateKey); + responseText = rememberedName is null + ? "I do not know your name yet." + : $"Your name is {rememberedName}."; + } + else + { + responseText = "I can remember your name if you tell me 'My name is ...'."; + } + + yield return new AgentResponseUpdate + { + MessageId = Guid.NewGuid().ToString("N"), + Role = ChatRole.Assistant, + Contents = [new TextContent(responseText)] + }; + + await Task.Yield(); + } + + private sealed class FakeAgentSession : AgentSession + { + public FakeAgentSession() + { + } + + [JsonConstructor] + public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag) + { + } + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIInMemorySessionStoreTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIInMemorySessionStoreTests.cs new file mode 100644 index 0000000000..8388c23bdc --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIInMemorySessionStoreTests.cs @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Agents.AI.Hosting.AGUI.AspNetCore; + +namespace Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests; + +/// +/// Unit tests for . +/// +public sealed class AGUIInMemorySessionStoreTests +{ + [Fact] + public async Task GetOrCreateSessionAsync_ReusesSessionForSameThreadAsync() + { + // Arrange + var store = new AGUIInMemorySessionStore(); + var agent = new CountingAgent(); + + // Act + AgentSession firstSession = await store.GetOrCreateSessionAsync(agent, "thread-1"); + AgentSession secondSession = await store.GetOrCreateSessionAsync(agent, "thread-1"); + AgentSession thirdSession = await store.GetOrCreateSessionAsync(agent, "thread-2"); + + // Assert + Assert.Same(firstSession, secondSession); + Assert.NotSame(firstSession, thirdSession); + Assert.Equal(2, agent.CreateSessionCallCount); + } + + [Fact] + public async Task GetOrCreateSessionAsync_UsesAgentIdentityInCacheKeyAsync() + { + // Arrange + var store = new AGUIInMemorySessionStore(); + var firstAgent = new CountingAgent("agent-1"); + var secondAgent = new CountingAgent("agent-2"); + + // Act + AgentSession firstSession = await store.GetOrCreateSessionAsync(firstAgent, "shared-thread"); + AgentSession secondSession = await store.GetOrCreateSessionAsync(secondAgent, "shared-thread"); + + // Assert + Assert.NotSame(firstSession, secondSession); + Assert.Equal(1, firstAgent.CreateSessionCallCount); + Assert.Equal(1, secondAgent.CreateSessionCallCount); + } + + [Fact] + public async Task GetOrCreateSessionAsync_DoesNotCollideWhenKeyPartsContainColonAsync() + { + // Arrange + var store = new AGUIInMemorySessionStore(); + var firstAgent = new CountingAgent("a:b"); + var secondAgent = new CountingAgent("a"); + + // Act + AgentSession firstSession = await store.GetOrCreateSessionAsync(firstAgent, "c"); + AgentSession secondSession = await store.GetOrCreateSessionAsync(secondAgent, "b:c"); + + // Assert + Assert.NotSame(firstSession, secondSession); + Assert.Equal(1, firstAgent.CreateSessionCallCount); + Assert.Equal(1, secondAgent.CreateSessionCallCount); + } + + [Fact] + public async Task GetOrCreateSessionAsync_RecreatesExpiredSessionAsync() + { + // Arrange + var store = new AGUIInMemorySessionStore(new AGUIInMemorySessionStoreOptions + { + SlidingExpiration = TimeSpan.FromMilliseconds(20) + }); + var agent = new CountingAgent(); + + // Act + AgentSession firstSession = await store.GetOrCreateSessionAsync(agent, "thread-1"); + await Task.Delay(TimeSpan.FromMilliseconds(80)); + AgentSession secondSession = await store.GetOrCreateSessionAsync(agent, "thread-1"); + + // Assert + Assert.NotSame(firstSession, secondSession); + Assert.Equal(2, agent.CreateSessionCallCount); + } + + [Fact] + public async Task GetOrCreateSessionAsync_CreatesSingleSessionUnderConcurrencyAsync() + { + // Arrange + var store = new AGUIInMemorySessionStore(); + var agent = new BlockingAgent(); + + // Act + Task firstTask = store.GetOrCreateSessionAsync(agent, "thread-1").AsTask(); + await agent.SessionCreationStarted.Task; + + Task secondTask = store.GetOrCreateSessionAsync(agent, "thread-1").AsTask(); + Task thirdTask = store.GetOrCreateSessionAsync(agent, "thread-1").AsTask(); + + agent.AllowSessionCreation.TrySetResult(); + AgentSession[] sessions = await Task.WhenAll(firstTask, secondTask, thirdTask); + + // Assert + Assert.Equal(1, agent.CreateSessionCallCount); + Assert.Same(sessions[0], sessions[1]); + Assert.Same(sessions[1], sessions[2]); + } + + private sealed class CountingAgent : AIAgent + { + private readonly string _id; + + public CountingAgent(string id = "counting-agent") + { + this._id = id; + } + + public int CreateSessionCallCount { get; private set; } + + protected override string? IdCore => this._id; + + protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) + { + this.CreateSessionCallCount++; + return new(new CountingSession()); + } + + protected override ValueTask SerializeSessionCoreAsync(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + + protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + + protected override Task RunCoreAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + + protected override IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + + private sealed class CountingSession : AgentSession + { + } + } + + private sealed class BlockingAgent : AIAgent + { + public TaskCompletionSource SessionCreationStarted { get; } = new(TaskCreationOptions.RunContinuationsAsynchronously); + + public TaskCompletionSource AllowSessionCreation { get; } = new(TaskCreationOptions.RunContinuationsAsynchronously); + + public int CreateSessionCallCount { get; private set; } + + protected override string? IdCore => "blocking-agent"; + + protected override async ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) + { + this.CreateSessionCallCount++; + this.SessionCreationStarted.TrySetResult(); + await this.AllowSessionCreation.Task.ConfigureAwait(false); + return new BlockingSession(); + } + + protected override ValueTask SerializeSessionCoreAsync(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + + protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + + protected override Task RunCoreAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + + protected override IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + + private sealed class BlockingSession : AgentSession + { + } + } +} \ No newline at end of file diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIServiceCollectionExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIServiceCollectionExtensionsTests.cs new file mode 100644 index 0000000000..5a786c5985 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIServiceCollectionExtensionsTests.cs @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Agents.AI.Hosting; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +namespace Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests; + +/// +/// Unit tests for . +/// +public sealed class AGUIServiceCollectionExtensionsTests +{ + [Fact] + public void AddAGUI_RegistersDefaultAgentSessionStore() + { + // Arrange + ServiceCollection services = new(); + + // Act + services.AddAGUI(); + using ServiceProvider serviceProvider = services.BuildServiceProvider(); + + // Assert + AgentSessionStore sessionStore = serviceProvider.GetRequiredService(); + Assert.IsType(sessionStore); + } + + [Fact] + public void AddAGUI_DoesNotOverrideCustomAgentSessionStore() + { + // Arrange + ServiceCollection services = new(); + RecordingAgentSessionStore sessionStore = new(); + services.AddSingleton(sessionStore); + + // Act + services.AddAGUI(); + using ServiceProvider serviceProvider = services.BuildServiceProvider(); + + // Assert + AgentSessionStore resolvedSessionStore = serviceProvider.GetRequiredService(); + Assert.Same(sessionStore, resolvedSessionStore); + } + + [Fact] + public void AddAGUI_ConfiguresDefaultInMemorySessionStoreOptions() + { + // Arrange + ServiceCollection services = new(); + + // Act + services.AddAGUI(options => + { + options.SizeLimit = 42; + options.SlidingExpiration = TimeSpan.FromMinutes(5); + }); + using ServiceProvider serviceProvider = services.BuildServiceProvider(); + + // Assert + AGUIInMemorySessionStoreOptions options = serviceProvider.GetRequiredService>().Value; + Assert.Equal(42, options.SizeLimit); + Assert.Equal(TimeSpan.FromMinutes(5), options.SlidingExpiration); + } + + private sealed class RecordingAgentSessionStore : AgentSessionStore + { + public override ValueTask SaveSessionAsync(AIAgent agent, string conversationId, AgentSession session, CancellationToken cancellationToken = default) + => ValueTask.CompletedTask; + + public override ValueTask GetSessionAsync(AIAgent agent, string conversationId, CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + } +} \ No newline at end of file