Skip to content

Commit 83d27cc

Browse files
authored
Fix semantic model caching in source generator (#172)
Fixes #170.
1 parent 847e768 commit 83d27cc

File tree

4 files changed

+181
-1
lines changed

4 files changed

+181
-1
lines changed

src/DocoptNet/CodeGeneration/SourceGenerator.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,19 @@ public void Execute(GeneratorExecutionContext context)
109109
var syntaxReceiver = (SyntaxReceiver)(context.SyntaxContextReceiver ?? throw new NullReferenceException());
110110

111111
SemanticModel? model = null;
112+
SyntaxTree? modelSyntaxTree = null;
112113

113114
var docoptTypes = new List<(string? Namespace, string Name, DocoptArgumentsAttribute? ArgumentsAttribute,
114115
SourceText Help, GenerationOptions Options)>();
115116

116117
foreach (var (cds, attributeData) in syntaxReceiver.ClassAttributes)
117118
{
118-
model ??= context.Compilation.GetSemanticModel(cds.SyntaxTree);
119+
if (model is null || modelSyntaxTree != cds.SyntaxTree)
120+
{
121+
model = context.Compilation.GetSemanticModel(cds.SyntaxTree);
122+
modelSyntaxTree = cds.SyntaxTree;
123+
}
124+
119125
var symbol = model.GetDeclaredSymbol(cds) as INamedTypeSymbol;
120126
if (symbol is null)
121127
continue;

tests/DocoptNet.Tests/CodeGeneration/SourceGeneratorTests.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,32 @@ sealed partial class ProgramArguments
192192
});
193193
}
194194

195+
[Test]
196+
public void Generate_with_classes_in_separate_files()
197+
{
198+
AssertMatchesSnapshot(new[]
199+
{
200+
("File1.cs", SourceText.From(@"
201+
namespace Namespace1
202+
{
203+
[DocoptNet.DocoptArguments]
204+
sealed partial class ProgramArguments
205+
{
206+
const string Help = ""Usage: program"";
207+
}
208+
}")),
209+
("File2.cs", SourceText.From(@"
210+
namespace Namespace2
211+
{
212+
[DocoptNet.DocoptArguments]
213+
sealed partial class ProgramArguments
214+
{
215+
const string Help = ""Usage: program"";
216+
}
217+
}")),
218+
});
219+
}
220+
195221
void AssertMatchesSnapshot((string Path, SourceText Text)[] sources,
196222
[CallerMemberName]string? callerName = null)
197223
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#nullable enable annotations
2+
3+
using System.Collections;
4+
using System.Collections.Generic;
5+
using DocoptNet;
6+
using DocoptNet.Internals;
7+
using Leaves = DocoptNet.Internals.ReadOnlyList<DocoptNet.Internals.LeafPattern>;
8+
9+
namespace Namespace1
10+
{
11+
partial class ProgramArguments : IEnumerable<KeyValuePair<string, object?>>
12+
{
13+
public const string Usage = "Usage: program";
14+
15+
static readonly IBaselineParser<ProgramArguments> Parser = GeneratedSourceModule.CreateParser(Help, Parse);
16+
17+
public static IBaselineParser<ProgramArguments> CreateParser() => Parser;
18+
19+
static IParser<ProgramArguments>.IResult Parse(IEnumerable<string> args, ParseFlags flags, string? version)
20+
{
21+
var options = new List<Option>
22+
{
23+
};
24+
25+
return GeneratedSourceModule.Parse(Help, Usage, args, options, flags, version, Parse);
26+
27+
static IParser<ProgramArguments>.IResult Parse(Leaves left)
28+
{
29+
var required = new RequiredMatcher(1, left, new Leaves());
30+
Match(ref required);
31+
if (!required.Result || required.Left.Count > 0)
32+
{
33+
return GeneratedSourceModule.CreateInputErrorResult<ProgramArguments>(string.Empty, Usage);
34+
}
35+
var collected = required.Collected;
36+
var result = new ProgramArguments();
37+
38+
return GeneratedSourceModule.CreateArgumentsResult(result);
39+
}
40+
41+
static void Match(ref RequiredMatcher required)
42+
{
43+
// Required(Required())
44+
var a = new RequiredMatcher(1, required.Left, required.Collected);
45+
while (a.Next())
46+
{
47+
// Required()
48+
var b = new RequiredMatcher(0, a.Left, a.Collected);
49+
while (b.Next())
50+
{
51+
if (!b.LastMatched)
52+
{
53+
break;
54+
}
55+
}
56+
a.Fold(b.Result);
57+
if (!a.LastMatched)
58+
{
59+
break;
60+
}
61+
}
62+
required.Fold(a.Result);
63+
}
64+
}
65+
66+
IEnumerator<KeyValuePair<string, object?>> GetEnumerator()
67+
{
68+
yield break;
69+
}
70+
71+
IEnumerator<KeyValuePair<string, object?>> IEnumerable<KeyValuePair<string, object?>>.GetEnumerator() => GetEnumerator();
72+
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
73+
}
74+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#nullable enable annotations
2+
3+
using System.Collections;
4+
using System.Collections.Generic;
5+
using DocoptNet;
6+
using DocoptNet.Internals;
7+
using Leaves = DocoptNet.Internals.ReadOnlyList<DocoptNet.Internals.LeafPattern>;
8+
9+
namespace Namespace2
10+
{
11+
partial class ProgramArguments : IEnumerable<KeyValuePair<string, object?>>
12+
{
13+
public const string Usage = "Usage: program";
14+
15+
static readonly IBaselineParser<ProgramArguments> Parser = GeneratedSourceModule.CreateParser(Help, Parse);
16+
17+
public static IBaselineParser<ProgramArguments> CreateParser() => Parser;
18+
19+
static IParser<ProgramArguments>.IResult Parse(IEnumerable<string> args, ParseFlags flags, string? version)
20+
{
21+
var options = new List<Option>
22+
{
23+
};
24+
25+
return GeneratedSourceModule.Parse(Help, Usage, args, options, flags, version, Parse);
26+
27+
static IParser<ProgramArguments>.IResult Parse(Leaves left)
28+
{
29+
var required = new RequiredMatcher(1, left, new Leaves());
30+
Match(ref required);
31+
if (!required.Result || required.Left.Count > 0)
32+
{
33+
return GeneratedSourceModule.CreateInputErrorResult<ProgramArguments>(string.Empty, Usage);
34+
}
35+
var collected = required.Collected;
36+
var result = new ProgramArguments();
37+
38+
return GeneratedSourceModule.CreateArgumentsResult(result);
39+
}
40+
41+
static void Match(ref RequiredMatcher required)
42+
{
43+
// Required(Required())
44+
var a = new RequiredMatcher(1, required.Left, required.Collected);
45+
while (a.Next())
46+
{
47+
// Required()
48+
var b = new RequiredMatcher(0, a.Left, a.Collected);
49+
while (b.Next())
50+
{
51+
if (!b.LastMatched)
52+
{
53+
break;
54+
}
55+
}
56+
a.Fold(b.Result);
57+
if (!a.LastMatched)
58+
{
59+
break;
60+
}
61+
}
62+
required.Fold(a.Result);
63+
}
64+
}
65+
66+
IEnumerator<KeyValuePair<string, object?>> GetEnumerator()
67+
{
68+
yield break;
69+
}
70+
71+
IEnumerator<KeyValuePair<string, object?>> IEnumerable<KeyValuePair<string, object?>>.GetEnumerator() => GetEnumerator();
72+
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
73+
}
74+
}

0 commit comments

Comments
 (0)