Skip to content

Commit c01c309

Browse files
authored
Implement presence API for IServiceConnectionContainer (#2125)
1 parent e74f0cf commit c01c309

13 files changed

Lines changed: 334 additions & 47 deletions

File tree

src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionManager.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) Microsoft. All rights reserved.
1+
// Copyright (c) Microsoft. All rights reserved.
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

44
using System;
@@ -168,4 +168,9 @@ private IEnumerable<IServiceConnectionContainer> GetConnections()
168168
}
169169
}
170170
}
171+
172+
public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, CancellationToken token = default)
173+
{
174+
throw new NotImplementedException();
175+
}
171176
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
3+
4+
using System.Collections.Generic;
5+
using System.Threading;
6+
7+
using Microsoft.Azure.SignalR.Protocol;
8+
9+
namespace Microsoft.Azure.SignalR;
10+
11+
/// <summary>
12+
/// Manager for presence operations.
13+
/// </summary>
14+
internal interface IPresenceManager
15+
{
16+
IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, CancellationToken token = default);
17+
}

src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
namespace Microsoft.Azure.SignalR;
88

9-
internal interface IServiceConnectionContainer : IServiceConnectionManager, IDisposable
9+
internal interface IServiceConnectionContainer : IServiceConnectionManager, IPresenceManager, IDisposable
1010
{
1111
ServiceConnectionStatus Status { get; }
1212

src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
// Copyright (c) Microsoft. All rights reserved.
1+
// Copyright (c) Microsoft. All rights reserved.
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

44
using System;
55
using System.Collections.Concurrent;
66
using System.Collections.Generic;
77
using System.Linq;
8+
using System.Runtime.CompilerServices;
89
using System.Threading;
910
using System.Threading.Tasks;
11+
1012
using Microsoft.Azure.SignalR.Common;
1113
using Microsoft.Azure.SignalR.Protocol;
1214
using Microsoft.Extensions.Logging;
@@ -16,7 +18,7 @@ namespace Microsoft.Azure.SignalR;
1618
/// <summary>
1719
/// A service connection container which sends message to multiple service endpoints.
1820
/// </summary>
19-
internal class MultiEndpointMessageWriter : IServiceMessageWriter
21+
internal class MultiEndpointMessageWriter : IServiceMessageWriter, IPresenceManager
2022
{
2123
private readonly ILogger _logger;
2224

@@ -55,8 +57,8 @@ public Task WriteAsync(ServiceMessage serviceMessage)
5557

5658
public Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default)
5759
{
58-
if (serviceMessage is CheckConnectionExistenceWithAckMessage
59-
|| serviceMessage is JoinGroupWithAckMessage
60+
if (serviceMessage is CheckConnectionExistenceWithAckMessage
61+
|| serviceMessage is JoinGroupWithAckMessage
6062
|| serviceMessage is LeaveGroupWithAckMessage)
6163
{
6264
return WriteSingleResultAckableMessage(serviceMessage, cancellationToken);
@@ -172,6 +174,44 @@ private async Task WriteSingleEndpointMessageAsync(HubServiceEndpoint endpoint,
172174
}
173175
}
174176

177+
public async IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, [EnumeratorCancellation] CancellationToken token = default)
178+
{
179+
if (TargetEndpoints.Length == 0)
180+
{
181+
Log.NoEndpointRouted(_logger, nameof(GroupMemberQueryMessage));
182+
yield break;
183+
}
184+
if (top <= 0)
185+
{
186+
throw new ArgumentOutOfRangeException(nameof(top), "Top must be greater than 0.");
187+
}
188+
foreach (var endpoint in TargetEndpoints)
189+
{
190+
IAsyncEnumerable<GroupMember> enumerable;
191+
try
192+
{
193+
enumerable = endpoint.ConnectionContainer.ListConnectionsInGroupAsync(groupName, top, tracingId, token);
194+
}
195+
catch (ServiceConnectionNotActiveException)
196+
{
197+
Log.FailedWritingMessageToEndpoint(_logger, nameof(GroupMemberQueryMessage), null, endpoint.ToString());
198+
continue;
199+
}
200+
await foreach (var member in enumerable)
201+
{
202+
yield return member;
203+
if (top.HasValue)
204+
{
205+
top--;
206+
if (top == 0)
207+
{
208+
yield break;
209+
}
210+
}
211+
}
212+
}
213+
}
214+
175215
internal static class Log
176216
{
177217
public const string FailedWritingMessageToEndpointTemplate = "{0} message {1} is not sent to endpoint {2} because all connections to this endpoint are offline.";
@@ -211,4 +251,4 @@ public static void FailedWritingMessageToEndpoint(ILogger logger, string message
211251
_failedWritingMessageToEndpoint(logger, messageType, tracingId, endpoint, null);
212252
}
213253
}
214-
}
254+
}

src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Linq;
77
using System.Threading;
88
using System.Threading.Tasks;
9+
910
using Microsoft.Azure.SignalR.Protocol;
1011
using Microsoft.Extensions.Logging;
1112

@@ -154,6 +155,14 @@ public Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage, Cancel
154155
return CreateMessageWriter(serviceMessage).WriteAckableMessageAsync(serviceMessage, cancellationToken);
155156
}
156157

158+
159+
public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, CancellationToken token = default)
160+
{
161+
var targetEndpoints = _routerEndpoints.needRouter ? _router.GetEndpointsForGroup(groupName, _routerEndpoints.endpoints) : _routerEndpoints.endpoints;
162+
var messageWriter = new MultiEndpointMessageWriter(targetEndpoints?.ToList(), _loggerFactory);
163+
return messageWriter.ListConnectionsInGroupAsync(groupName, top, tracingId, token);
164+
}
165+
157166
public Task StartGetServersPing()
158167
{
159168
return Task.WhenAll(_routerEndpoints.endpoints.Select(c => c.ConnectionContainer.StartGetServersPing()));
@@ -499,4 +508,4 @@ public static void FailedRemovingConnectionForEndpoint(ILogger logger, string en
499508
_failedRemovingConnectionForEndpoint(logger, endpoint, ex);
500509
}
501510
}
502-
}
511+
}

src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) Microsoft. All rights reserved.
1+
// Copyright (c) Microsoft. All rights reserved.
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

44
using System;
@@ -8,6 +8,7 @@
88
using System.Runtime.CompilerServices;
99
using System.Threading;
1010
using System.Threading.Tasks;
11+
1112
using Microsoft.Azure.SignalR.Common;
1213
using Microsoft.Azure.SignalR.Protocol;
1314
using Microsoft.Extensions.Logging;
@@ -221,7 +222,7 @@ public virtual Task HandlePingAsync(PingMessage pingMessage)
221222

222223
public void HandleAck(AckMessage ackMessage)
223224
{
224-
_ackHandler.TriggerAck(ackMessage.AckId, (AckStatus)ackMessage.Status);
225+
_ackHandler.TriggerAck(ackMessage.AckId, (AckStatus)ackMessage.Status, ackMessage.Payload);
225226
}
226227

227228
public virtual Task WriteAsync(ServiceMessage serviceMessage)
@@ -249,6 +250,60 @@ public async Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage,
249250
return AckHandler.HandleAckStatus(ackableMessage, status);
250251
}
251252

253+
public async IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, [EnumeratorCancellation] CancellationToken token = default)
254+
{
255+
if (string.IsNullOrWhiteSpace(groupName))
256+
{
257+
throw new ArgumentException($"'{nameof(groupName)}' cannot be null or whitespace.", nameof(groupName));
258+
}
259+
if (top != null && top <= 0)
260+
{
261+
throw new ArgumentException($"'{nameof(top)}' must be greater than 0.", nameof(top));
262+
}
263+
var message = new GroupMemberQueryMessage() { GroupName = groupName, Top = top, TracingId = tracingId };
264+
do
265+
{
266+
var response = await InvokeAsync<GroupMemberQueryResponse>(message, token);
267+
foreach (var member in response.Members)
268+
{
269+
yield return member;
270+
}
271+
if (response.ContinuationToken == null)
272+
{
273+
yield break;
274+
}
275+
if (message.Top != null)
276+
{
277+
message.Top -= response.Members.Count;
278+
}
279+
message.ContinuationToken = response.ContinuationToken;
280+
} while (true);
281+
}
282+
283+
/// <summary>
284+
/// <see cref="WriteAckableMessageAsync(ServiceMessage, CancellationToken)"/> only checks <see cref="AckMessage.Status"/> as the response,
285+
/// while this method checks <see cref="AckMessage.Payload"/> and deserialize it to <typeparamref name="T"/>.
286+
/// </summary>
287+
/// Made "interval virtual" for testing
288+
internal virtual async Task<T> InvokeAsync<T>(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) where T : notnull, new()
289+
{
290+
if (serviceMessage is not IAckableMessage ackableMessage)
291+
{
292+
throw new ArgumentException($"{nameof(serviceMessage)} is not {nameof(IAckableMessage)}");
293+
}
294+
295+
var task = _ackHandler.CreateSingleAck<T>(out var id, null, cancellationToken);
296+
ackableMessage.AckId = id;
297+
298+
// Sending regular messages completes as soon as the data leaves the outbound pipe,
299+
// whereas ackable ones complete upon full roundtrip of the message and the ack (or timeout).
300+
// Therefore sending them over different connections creates a possibility for processing them out of original order.
301+
// By sending both message types over the same connection we ensure that they are sent (and processed) in their original order.
302+
await WriteMessageAsync(serviceMessage);
303+
304+
return await task;
305+
}
306+
252307
public virtual Task OfflineAsync(GracefulShutdownMode mode, CancellationToken token)
253308
{
254309
_terminated = true;

src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,17 @@ public Task<AckStatus> CreateSingleAck(out int id, TimeSpan? ackTimeout = defaul
5050
return info.Task;
5151
}
5252

53-
public Task<T> CreateSingleAck<T>(out int id, TimeSpan? ackTimeout = default, CancellationToken cancellationToken = default) where T : IMessagePackSerializable, new()
54-
{
55-
id = NextId();
56-
if (_disposed)
53+
public Task<T> CreateSingleAck<T>(out int id, TimeSpan? ackTimeout = default, CancellationToken cancellationToken = default) where T : notnull, new()
5754
{
58-
return Task.FromResult(new T());
55+
id = NextId();
56+
if (_disposed)
57+
{
58+
return Task.FromResult(new T());
59+
}
60+
var info = (SinglePayloadAck<T>)_acks.GetOrAdd(id, _ => new SinglePayloadAck<T>(ackTimeout ?? _defaultAckTimeout));
61+
cancellationToken.Register(info.Cancel);
62+
return info.Task.ContinueWith(task => task.Result);
5963
}
60-
var info = (IAckInfo<IMessagePackSerializable>)_acks.GetOrAdd(id, _ => new SinglePayloadAck<T>(ackTimeout ?? _defaultAckTimeout));
61-
cancellationToken.Register(info.Cancel);
62-
return info.Task.ContinueWith(task => (T)task.Result);
63-
}
6464

6565
public static bool HandleAckStatus(IAckableMessage message, AckStatus status)
6666
{
@@ -210,19 +210,19 @@ public override bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = nul
210210
_tcs.TrySetResult(status);
211211
}
212212

213-
private sealed class SinglePayloadAck<T> : SingleAckInfo<IMessagePackSerializable> where T : IMessagePackSerializable, new()
214-
{
215-
public SinglePayloadAck(TimeSpan timeout) : base(timeout) { }
216-
public override bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = null)
213+
private sealed class SinglePayloadAck<T> : SingleAckInfo<T> where T : notnull, new()
217214
{
218-
if (status == AckStatus.Timeout)
215+
public SinglePayloadAck(TimeSpan timeout) : base(timeout) { }
216+
public override bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = null)
219217
{
220-
return _tcs.TrySetException(new TimeoutException($"Waiting for a {typeof(T).Name} response timed out."));
221-
}
222-
if (payload == null)
223-
{
224-
return _tcs.TrySetException(new InvalidDataException($"The expected payload is null."));
225-
}
218+
if (status == AckStatus.Timeout)
219+
{
220+
return _tcs.TrySetException(new TimeoutException($"Waiting for a {typeof(T).Name} response timed out."));
221+
}
222+
if (payload == null)
223+
{
224+
return _tcs.TrySetException(new InvalidDataException($"The expected payload is null."));
225+
}
226226

227227
try
228228
{

test/Microsoft.Azure.SignalR.Common.Tests/ServiceConnectionContainerBaseTests.cs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,14 @@
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

44
using System;
5+
using System.Buffers;
6+
using System.Linq;
7+
using System.Threading;
8+
using System.Threading.Tasks;
9+
using Microsoft.Azure.SignalR.Protocol;
510
using Microsoft.Azure.SignalR.Tests.Common;
611
using Microsoft.Extensions.Logging;
12+
using Moq;
713
using Xunit;
814
using Xunit.Abstractions;
915

@@ -101,4 +107,62 @@ public void TestStrongConnectionStatus()
101107
Assert.True(endpoint1.Online);
102108
}
103109
}
110+
111+
[Fact]
112+
public async Task TestInvokeAsync()
113+
{
114+
var endpoint1 = new TestHubServiceEndpoint();
115+
var conn1 = new TestServiceConnection();
116+
var scf = new TestServiceConnectionFactory(endpoint1 => conn1);
117+
var container = new WeakServiceConnectionContainer(scf, 5, endpoint1, Mock.Of<ILogger>());
118+
var queryMessage = new GroupMemberQueryMessage() { GroupName = "group" };
119+
var invokeTask = container.InvokeAsync<GroupMemberQueryResponse>(queryMessage, default);
120+
121+
var expectedResponse = new GroupMemberQueryResponse()
122+
{
123+
ContinuationToken = "abc",
124+
Members = [new() { ConnectionId = "1" }, new() { ConnectionId = "2" }]
125+
};
126+
var buffer = new ArrayBufferWriter<byte>();
127+
new ServiceProtocol().WriteMessagePayload(expectedResponse, buffer);
128+
AckHandler.Singleton.TriggerAck(queryMessage.AckId, AckStatus.Ok, new ReadOnlySequence<byte>(buffer.WrittenMemory));
129+
var response = await invokeTask;
130+
Assert.Equal(queryMessage, conn1.ReceivedMessages.Single());
131+
Assert.Equal(expectedResponse.ContinuationToken, response.ContinuationToken);
132+
Assert.True(expectedResponse.Members.SequenceEqual(response.Members));
133+
}
134+
135+
[Fact]
136+
public async Task TestListConnectionsInGroupAsync()
137+
{
138+
var conn = new TestServiceConnection();
139+
var groupName = "groupName";
140+
var top = 3;
141+
var tracingId = (ulong)1;
142+
var connectionContainerMock = new Mock<ServiceConnectionContainerBase>(
143+
new TestServiceConnectionFactory(endpoint => conn),
144+
5,
145+
new TestHubServiceEndpoint(),
146+
null,
147+
Mock.Of<ILogger>(),
148+
null);
149+
connectionContainerMock.SetupSequence(c => c.InvokeAsync<GroupMemberQueryResponse>(
150+
It.IsAny<ServiceMessage>(), It.IsAny<CancellationToken>()))
151+
.ReturnsAsync(new GroupMemberQueryResponse() { ContinuationToken = "abc", Members = [new() { ConnectionId = "1" }, new() { ConnectionId = "2" }] })
152+
.ReturnsAsync(new GroupMemberQueryResponse() { ContinuationToken = null, Members = [new() { ConnectionId = "3" }] });
153+
var enumerator = connectionContainerMock.Object
154+
.ListConnectionsInGroupAsync(groupName, top, tracingId)
155+
.GetAsyncEnumerator();
156+
Assert.True(await enumerator.MoveNextAsync());
157+
Assert.Equal("1", enumerator.Current.ConnectionId);
158+
connectionContainerMock.Verify(c => c.InvokeAsync<GroupMemberQueryResponse>(
159+
It.Is<GroupMemberQueryMessage>(m => m.GroupName == groupName && m.Top == 3 && m.TracingId == tracingId), It.IsAny<CancellationToken>()), Times.Once);
160+
connectionContainerMock.Invocations.Clear();
161+
Assert.True(await enumerator.MoveNextAsync());
162+
Assert.True(await enumerator.MoveNextAsync());
163+
Assert.Equal("3", enumerator.Current.ConnectionId);
164+
connectionContainerMock.Verify(c => c.InvokeAsync<GroupMemberQueryResponse>(
165+
It.Is<GroupMemberQueryMessage>(m => m.GroupName == groupName && m.Top == 1 && m.TracingId == tracingId && m.ContinuationToken == "abc"), It.IsAny<CancellationToken>()), Times.Once);
166+
Assert.False(await enumerator.MoveNextAsync());
167+
}
104168
}

0 commit comments

Comments
 (0)