Skip to content

Commit 8762715

Browse files
authored
feat(mocks): add RefStructArg<T> for ref struct parameter visibility in setup/verify API (#5003)
* feat(mocks): add non-generic AnyMatcher for ref struct arg positions * feat(mocks): add RefStructArg<T> with allows ref struct for net9.0+ * feat(mocks): source gen emits RefStructArg<T> with #if NET9_0_OR_GREATER blocks * test(mocks): update ref struct snapshot for RefStructArg<T> support * test(mocks): add RefStructArg<T> runtime tests for net9.0+ * refactor(mocks): address code review feedback on RefStructArg<T> PR - Extract EmitArgsArrayVariable helper in MockImplBuilder to deduplicate 3 identical #if NET9_0_OR_GREATER args array blocks - Move HasRefStructParams to computed property on MockMemberModel (derived from Parameters, excluded from equality) - Merge 3 pairs of GenerateTyped*Overload methods and 2 BuildCastArgs overloads into single methods with optional allNonOutParams parameter - Fix O(n^2) IndexOf in BuildCastArgs by pre-building a dictionary - Add net9.0+ equivalents of disabled tests using RefStructArg<T>.Any - Add XML remarks to RefStructArg<T> documenting limitations * docs(mocks): document RefStructArg<T> in argument matchers page
1 parent 3733777 commit 8762715

8 files changed

Lines changed: 437 additions & 42 deletions

File tree

TUnit.Mocks.SourceGenerator.Tests/Snapshots/Interface_With_RefStruct_Parameters.verified.txt

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,22 @@ namespace TUnit.Mocks.Generated
4141

4242
public void Process(global::System.ReadOnlySpan<byte> data)
4343
{
44-
_engine.HandleCall(0, "Process", global::System.Array.Empty<object?>());
44+
#if NET9_0_OR_GREATER
45+
var __args = new object?[] { null };
46+
#else
47+
var __args = global::System.Array.Empty<object?>();
48+
#endif
49+
_engine.HandleCall(0, "Process", __args);
4550
}
4651

4752
public int Parse(global::System.ReadOnlySpan<char> text)
4853
{
49-
return _engine.HandleCallWithReturn<int>(1, "Parse", global::System.Array.Empty<object?>(), default);
54+
#if NET9_0_OR_GREATER
55+
var __args = new object?[] { null };
56+
#else
57+
var __args = global::System.Array.Empty<object?>();
58+
#endif
59+
return _engine.HandleCallWithReturn<int>(1, "Parse", __args, default);
5060
}
5161

5262
public string GetName()
@@ -72,17 +82,33 @@ namespace TUnit.Mocks.Generated
7282
{
7383
public static class IBufferProcessor_MockMemberExtensions
7484
{
85+
#if NET9_0_OR_GREATER
86+
public static global::TUnit.Mocks.VoidMockMethodCall Process(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock, global::TUnit.Mocks.Arguments.RefStructArg<global::System.ReadOnlySpan<byte>> data)
87+
{
88+
var matchers = new global::TUnit.Mocks.Arguments.IArgumentMatcher[] { data.Matcher };
89+
return new global::TUnit.Mocks.VoidMockMethodCall(mock.Engine, 0, "Process", matchers);
90+
}
91+
#else
7592
public static global::TUnit.Mocks.VoidMockMethodCall Process(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock)
7693
{
7794
var matchers = global::System.Array.Empty<global::TUnit.Mocks.Arguments.IArgumentMatcher>();
7895
return new global::TUnit.Mocks.VoidMockMethodCall(mock.Engine, 0, "Process", matchers);
7996
}
97+
#endif
8098

99+
#if NET9_0_OR_GREATER
100+
public static global::TUnit.Mocks.MockMethodCall<int> Parse(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock, global::TUnit.Mocks.Arguments.RefStructArg<global::System.ReadOnlySpan<char>> text)
101+
{
102+
var matchers = new global::TUnit.Mocks.Arguments.IArgumentMatcher[] { text.Matcher };
103+
return new global::TUnit.Mocks.MockMethodCall<int>(mock.Engine, 1, "Parse", matchers);
104+
}
105+
#else
81106
public static global::TUnit.Mocks.MockMethodCall<int> Parse(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock)
82107
{
83108
var matchers = global::System.Array.Empty<global::TUnit.Mocks.Arguments.IArgumentMatcher>();
84109
return new global::TUnit.Mocks.MockMethodCall<int>(mock.Engine, 1, "Parse", matchers);
85110
}
111+
#endif
86112

87113
public static global::TUnit.Mocks.MockMethodCall<string> GetName(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock)
88114
{

TUnit.Mocks.SourceGenerator/Builders/MockImplBuilder.cs

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ private static void GenerateWrapMethodBody(CodeWriter writer, MockMemberModel me
188188
}
189189
}
190190

191-
var argsArray = GetArgsArrayExpression(method);
191+
var argsArray = EmitArgsArrayVariable(writer, method);
192192
var argPassList = GetArgPassList(method);
193193

194194
if (method.IsVoid && !method.IsAsync)
@@ -461,7 +461,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
461461
}
462462
}
463463

464-
var argsArray = GetArgsArrayExpression(method);
464+
var argsArray = EmitArgsArrayVariable(writer, method);
465465
var argPassList = GetArgPassList(method);
466466

467467
if (method.IsVoid && !method.IsAsync)
@@ -551,7 +551,7 @@ private static void GenerateEngineDispatchBody(CodeWriter writer, MockMemberMode
551551
}
552552
}
553553

554-
var argsArray = GetArgsArrayExpression(method);
554+
var argsArray = EmitArgsArrayVariable(writer, method);
555555

556556
var hasOutRef = HasOutRefParams(method);
557557

@@ -955,14 +955,32 @@ private static void EmitOutRefReadback(CodeWriter writer, MockMemberModel method
955955
}
956956
}
957957

958-
private static string GetArgsArrayExpression(MockMemberModel method)
958+
private static string EmitArgsArrayVariable(CodeWriter writer, MockMemberModel method)
959959
{
960-
// Only include non-out, non-ref-struct parameters in args array
961-
// (ref structs cannot be boxed into object?[])
962-
var matchableParams = method.Parameters.Where(p => p.Direction != ParameterDirection.Out && !p.IsRefStruct).ToList();
960+
if (!method.HasRefStructParams)
961+
return GetArgsArrayExpression(method, false);
962+
963+
writer.AppendLine("#if NET9_0_OR_GREATER");
964+
writer.AppendLine($"var __args = {GetArgsArrayExpression(method, true)};");
965+
writer.AppendLine("#else");
966+
writer.AppendLine($"var __args = {GetArgsArrayExpression(method, false)};");
967+
writer.AppendLine("#endif");
968+
return "__args";
969+
}
970+
971+
private static string GetArgsArrayExpression(MockMemberModel method, bool includeRefStructSentinels)
972+
{
973+
var nonOutParams = method.Parameters.Where(p => p.Direction != ParameterDirection.Out).ToList();
974+
if (includeRefStructSentinels)
975+
{
976+
if (nonOutParams.Count == 0) return "global::System.Array.Empty<object?>()";
977+
var args = string.Join(", ", nonOutParams.Select(p => p.IsRefStruct ? "null" : p.Name));
978+
return $"new object?[] {{ {args} }}";
979+
}
980+
var matchableParams = nonOutParams.Where(p => !p.IsRefStruct).ToList();
963981
if (matchableParams.Count == 0) return "global::System.Array.Empty<object?>()";
964-
var args = string.Join(", ", matchableParams.Select(p => p.Name));
965-
return $"new object?[] {{ {args} }}";
982+
var argsStr = string.Join(", ", matchableParams.Select(p => p.Name));
983+
return $"new object?[] {{ {argsStr} }}";
966984
}
967985

968986
/// <summary>

TUnit.Mocks.SourceGenerator/Builders/MockMembersBuilder.cs

Lines changed: 97 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -110,21 +110,23 @@ private static void GenerateUnifiedSealedClass(CodeWriter writer, MockMemberMode
110110

111111
var wrapperName = GetWrapperName(safeName, method);
112112
var matchableParams = method.Parameters.Where(p => p.Direction != ParameterDirection.Out && !p.IsRefStruct).ToList();
113+
var hasRefStructParams = method.HasRefStructParams;
114+
var allNonOutParams = method.Parameters.Where(p => p.Direction != ParameterDirection.Out).ToList();
113115

114116
// Ref struct returns use the void wrapper (can't use generic type args with ref structs)
115117
if (method.IsVoid || method.IsRefStructReturn)
116118
{
117-
GenerateVoidUnifiedClass(writer, wrapperName, matchableParams, events, method.Parameters);
119+
GenerateVoidUnifiedClass(writer, wrapperName, matchableParams, events, method.Parameters, hasRefStructParams, allNonOutParams);
118120
}
119121
else
120122
{
121-
GenerateReturnUnifiedClass(writer, wrapperName, matchableParams, setupReturnType, events, method.Parameters);
123+
GenerateReturnUnifiedClass(writer, wrapperName, matchableParams, setupReturnType, events, method.Parameters, hasRefStructParams, allNonOutParams);
122124
}
123125
}
124126

125127
private static void GenerateReturnUnifiedClass(CodeWriter writer, string wrapperName,
126128
List<MockParameterModel> nonOutParams, string returnType, EquatableArray<MockEventModel> events,
127-
EquatableArray<MockParameterModel> allParameters)
129+
EquatableArray<MockParameterModel> allParameters, bool hasRefStructParams, List<MockParameterModel> allNonOutParams)
128130
{
129131
var builderType = $"global::TUnit.Mocks.Setup.MethodSetupBuilder<{returnType}>";
130132
var hasOutRef = allParameters.Any(p => p.Direction == ParameterDirection.Out || p.Direction == ParameterDirection.Ref);
@@ -198,11 +200,30 @@ private static void GenerateReturnUnifiedClass(CodeWriter writer, string wrapper
198200
if (nonOutParams.Count >= 1)
199201
{
200202
writer.AppendLine();
201-
GenerateTypedReturnsOverload(writer, nonOutParams, returnType, wrapperName);
202-
writer.AppendLine();
203-
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName);
204-
writer.AppendLine();
205-
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName);
203+
if (hasRefStructParams)
204+
{
205+
writer.AppendLine("#if NET9_0_OR_GREATER");
206+
GenerateTypedReturnsOverload(writer, nonOutParams, returnType, wrapperName, allNonOutParams);
207+
writer.AppendLine();
208+
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName, allNonOutParams);
209+
writer.AppendLine();
210+
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName, allNonOutParams);
211+
writer.AppendLine("#else");
212+
GenerateTypedReturnsOverload(writer, nonOutParams, returnType, wrapperName);
213+
writer.AppendLine();
214+
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName);
215+
writer.AppendLine();
216+
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName);
217+
writer.AppendLine("#endif");
218+
}
219+
else
220+
{
221+
GenerateTypedReturnsOverload(writer, nonOutParams, returnType, wrapperName);
222+
writer.AppendLine();
223+
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName);
224+
writer.AppendLine();
225+
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName);
226+
}
206227
}
207228

208229
// Typed out/ref parameter setters
@@ -239,7 +260,7 @@ private static void GenerateReturnUnifiedClass(CodeWriter writer, string wrapper
239260

240261
private static void GenerateVoidUnifiedClass(CodeWriter writer, string wrapperName,
241262
List<MockParameterModel> nonOutParams, EquatableArray<MockEventModel> events,
242-
EquatableArray<MockParameterModel> allParameters)
263+
EquatableArray<MockParameterModel> allParameters, bool hasRefStructParams, List<MockParameterModel> allNonOutParams)
243264
{
244265
var builderType = "global::TUnit.Mocks.Setup.VoidMethodSetupBuilder";
245266
var hasOutRef = allParameters.Any(p => p.Direction == ParameterDirection.Out || p.Direction == ParameterDirection.Ref);
@@ -307,9 +328,24 @@ private static void GenerateVoidUnifiedClass(CodeWriter writer, string wrapperNa
307328
if (nonOutParams.Count >= 1)
308329
{
309330
writer.AppendLine();
310-
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName);
311-
writer.AppendLine();
312-
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName);
331+
if (hasRefStructParams)
332+
{
333+
writer.AppendLine("#if NET9_0_OR_GREATER");
334+
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName, allNonOutParams);
335+
writer.AppendLine();
336+
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName, allNonOutParams);
337+
writer.AppendLine("#else");
338+
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName);
339+
writer.AppendLine();
340+
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName);
341+
writer.AppendLine("#endif");
342+
}
343+
else
344+
{
345+
GenerateTypedCallbackOverload(writer, nonOutParams, wrapperName);
346+
writer.AppendLine();
347+
GenerateTypedThrowsOverload(writer, nonOutParams, wrapperName);
348+
}
313349
}
314350

315351
// Typed out/ref parameter setters
@@ -345,11 +381,11 @@ private static void GenerateVoidUnifiedClass(CodeWriter writer, string wrapperNa
345381
}
346382

347383
private static void GenerateTypedReturnsOverload(CodeWriter writer, List<MockParameterModel> nonOutParams,
348-
string returnType, string wrapperName)
384+
string returnType, string wrapperName, List<MockParameterModel>? allNonOutParams = null)
349385
{
350386
var typeList = string.Join(", ", nonOutParams.Select(p => p.FullyQualifiedType));
351387
var funcType = $"global::System.Func<{typeList}, {returnType}>";
352-
var castArgs = BuildCastArgs(nonOutParams);
388+
var castArgs = BuildCastArgs(nonOutParams, allNonOutParams);
353389

354390
writer.AppendLine("/// <summary>Configure a typed computed return value using the actual method parameters.</summary>");
355391
using (writer.Block($"public {wrapperName} Returns({funcType} factory)"))
@@ -360,11 +396,11 @@ private static void GenerateTypedReturnsOverload(CodeWriter writer, List<MockPar
360396
}
361397

362398
private static void GenerateTypedCallbackOverload(CodeWriter writer, List<MockParameterModel> nonOutParams,
363-
string wrapperName)
399+
string wrapperName, List<MockParameterModel>? allNonOutParams = null)
364400
{
365401
var typeList = string.Join(", ", nonOutParams.Select(p => p.FullyQualifiedType));
366402
var actionType = $"global::System.Action<{typeList}>";
367-
var castArgs = BuildCastArgs(nonOutParams);
403+
var castArgs = BuildCastArgs(nonOutParams, allNonOutParams);
368404

369405
writer.AppendLine("/// <summary>Execute a typed callback using the actual method parameters.</summary>");
370406
using (writer.Block($"public {wrapperName} Callback({actionType} callback)"))
@@ -375,11 +411,11 @@ private static void GenerateTypedCallbackOverload(CodeWriter writer, List<MockPa
375411
}
376412

377413
private static void GenerateTypedThrowsOverload(CodeWriter writer, List<MockParameterModel> nonOutParams,
378-
string wrapperName)
414+
string wrapperName, List<MockParameterModel>? allNonOutParams = null)
379415
{
380416
var typeList = string.Join(", ", nonOutParams.Select(p => p.FullyQualifiedType));
381417
var funcType = $"global::System.Func<{typeList}, global::System.Exception>";
382-
var castArgs = BuildCastArgs(nonOutParams);
418+
var castArgs = BuildCastArgs(nonOutParams, allNonOutParams);
383419

384420
writer.AppendLine("/// <summary>Configure a typed computed exception using the actual method parameters.</summary>");
385421
using (writer.Block($"public {wrapperName} Throws({funcType} exceptionFactory)"))
@@ -442,13 +478,32 @@ private static void GenerateTypedOutRefMethods(CodeWriter writer, EquatableArray
442478
private static string ToPascalCase(string name)
443479
=> string.IsNullOrEmpty(name) ? name : char.ToUpperInvariant(name[0]) + name[1..];
444480

445-
private static string BuildCastArgs(List<MockParameterModel> nonOutParams)
481+
private static string BuildCastArgs(List<MockParameterModel> nonOutParams, List<MockParameterModel>? allNonOutParams = null)
446482
{
447-
return string.Join(", ", nonOutParams.Select((p, i) =>
448-
$"({p.FullyQualifiedType})args[{i}]!"));
483+
if (allNonOutParams is null)
484+
return string.Join(", ", nonOutParams.Select((p, i) => $"({p.FullyQualifiedType})args[{i}]!"));
485+
486+
var indexMap = allNonOutParams.Select((p, i) => (p, i)).ToDictionary(x => x.p, x => x.i);
487+
return string.Join(", ", nonOutParams.Select(p => $"({p.FullyQualifiedType})args[{indexMap[p]}]!"));
449488
}
450489

451490
private static void GenerateMemberMethod(CodeWriter writer, MockMemberModel method, MockTypeModel model, string safeName)
491+
{
492+
if (method.HasRefStructParams)
493+
{
494+
writer.AppendLine("#if NET9_0_OR_GREATER");
495+
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: true);
496+
writer.AppendLine("#else");
497+
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: false);
498+
writer.AppendLine("#endif");
499+
}
500+
else
501+
{
502+
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: false);
503+
}
504+
}
505+
506+
private static void EmitMemberMethodBody(CodeWriter writer, MockMemberModel method, MockTypeModel model, string safeName, bool includeRefStructArgs)
452507
{
453508
// For async methods (Task<T>/ValueTask<T>), unwrap the return type so users write .Returns(5) not .Returns(Task.FromResult(5))
454509
// For void-async methods (Task/ValueTask), IsVoid is already true
@@ -474,7 +529,7 @@ private static void GenerateMemberMethod(CodeWriter writer, MockMemberModel meth
474529
returnType = $"global::TUnit.Mocks.MockMethodCall<{setupReturnType}>";
475530
}
476531

477-
var paramList = GetArgParameterList(method);
532+
var paramList = GetArgParameterList(method, includeRefStructArgs);
478533
var typeParams = GetTypeParameterList(method);
479534
var constraints = GetConstraintClauses(method);
480535

@@ -484,9 +539,10 @@ private static void GenerateMemberMethod(CodeWriter writer, MockMemberModel meth
484539

485540
using (writer.Block($"public static {returnType} {safeMemberName}{typeParams}({fullParamList}){constraints}"))
486541
{
487-
// Build matchers array (exclude out and ref struct params)
488-
var matchableParams = method.Parameters
489-
.Where(p => p.Direction != ParameterDirection.Out && !p.IsRefStruct).ToList();
542+
// Build matchers array
543+
var matchableParams = includeRefStructArgs
544+
? method.Parameters.Where(p => p.Direction != ParameterDirection.Out).ToList()
545+
: method.Parameters.Where(p => p.Direction != ParameterDirection.Out && !p.IsRefStruct).ToList();
490546

491547
if (matchableParams.Count == 0)
492548
{
@@ -576,13 +632,23 @@ private static void GenerateRaiseExtensionMethods(CodeWriter writer, MockTypeMod
576632
}
577633
}
578634

579-
private static string GetArgParameterList(MockMemberModel method)
635+
private static string GetArgParameterList(MockMemberModel method, bool includeRefStructArgs)
580636
{
581-
// Only include non-out, non-ref-struct parameters as Arg<T> in setup
582-
// (ref structs cannot be used as generic type arguments)
583-
return string.Join(", ", method.Parameters
584-
.Where(p => p.Direction != ParameterDirection.Out && !p.IsRefStruct)
585-
.Select(p => $"global::TUnit.Mocks.Arguments.Arg<{p.FullyQualifiedType}> {p.Name}"));
637+
var parts = new List<string>();
638+
foreach (var p in method.Parameters)
639+
{
640+
if (p.Direction == ParameterDirection.Out) continue;
641+
if (p.IsRefStruct)
642+
{
643+
if (includeRefStructArgs)
644+
parts.Add($"global::TUnit.Mocks.Arguments.RefStructArg<{p.FullyQualifiedType}> {p.Name}");
645+
}
646+
else
647+
{
648+
parts.Add($"global::TUnit.Mocks.Arguments.Arg<{p.FullyQualifiedType}> {p.Name}");
649+
}
650+
}
651+
return string.Join(", ", parts);
586652
}
587653

588654
private static string GetTypeParameterList(MockMemberModel method)

TUnit.Mocks.SourceGenerator/Models/MockMemberModel.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Linq;
23

34
namespace TUnit.Mocks.SourceGenerator.Models;
45

@@ -31,6 +32,12 @@ internal sealed record MockMemberModel : IEquatable<MockMemberModel>
3132
public bool IsProtected { get; init; }
3233
public bool IsRefStructReturn { get; init; }
3334

35+
/// <summary>
36+
/// Returns true if the method has any non-out ref struct parameters.
37+
/// Computed from <see cref="Parameters"/> — does not participate in equality.
38+
/// </summary>
39+
public bool HasRefStructParams => Parameters.Any(p => p.IsRefStruct && p.Direction != ParameterDirection.Out);
40+
3441
public bool Equals(MockMemberModel? other)
3542
{
3643
if (other is null) return false;

0 commit comments

Comments
 (0)