Skip to content

Commit 093df0c

Browse files
authored
Extract diagnostic reporting out of COM class generator (#125308)
1 parent 85149f3 commit 093df0c

7 files changed

Lines changed: 232 additions & 102 deletions

File tree

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Collections.Generic;
5+
using System.Collections.Immutable;
6+
using System.Linq;
7+
using Microsoft.CodeAnalysis;
8+
using Microsoft.CodeAnalysis.CSharp;
9+
using Microsoft.CodeAnalysis.CSharp.Syntax;
10+
using Microsoft.CodeAnalysis.Diagnostics;
11+
using Microsoft.CodeAnalysis.DotnetRuntime.Extensions;
12+
13+
namespace Microsoft.Interop.Analyzers;
14+
15+
[DiagnosticAnalyzer(LanguageNames.CSharp)]
16+
public sealed class ComClassGeneratorDiagnosticsAnalyzer : DiagnosticAnalyzer
17+
{
18+
public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } =
19+
ImmutableArray.Create(
20+
GeneratorDiagnostics.RequiresAllowUnsafeBlocks,
21+
GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier,
22+
GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface);
23+
24+
public override void Initialize(AnalysisContext context)
25+
{
26+
context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
27+
context.EnableConcurrentExecution();
28+
29+
context.RegisterCompilationStartAction(static context =>
30+
{
31+
bool unsafeCodeIsEnabled = context.Compilation.Options is CSharpCompilationOptions { AllowUnsafe: true };
32+
INamedTypeSymbol? generatedComClassAttributeType = context.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComClassAttribute);
33+
34+
// We use this type only to report warning diagnostic. We also don't report a warning if there is at least one error.
35+
// Given that with unsafe code disabled we will get an error on each declaration, we can skip
36+
// unnecessary work of getting this symbol here
37+
INamedTypeSymbol? generatedComInterfaceAttributeType = unsafeCodeIsEnabled
38+
? context.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute)
39+
: null;
40+
41+
context.RegisterSymbolAction(context => AnalyzeNamedType(context, unsafeCodeIsEnabled, generatedComClassAttributeType, generatedComInterfaceAttributeType), SymbolKind.NamedType);
42+
});
43+
}
44+
45+
private static void AnalyzeNamedType(SymbolAnalysisContext context, bool unsafeCodeIsEnabled, INamedTypeSymbol? generatedComClassAttributeType, INamedTypeSymbol? generatedComInterfaceAttributeType)
46+
{
47+
if (context.Symbol is not INamedTypeSymbol { TypeKind: TypeKind.Class } classToAnalyze)
48+
{
49+
return;
50+
}
51+
52+
if (!classToAnalyze.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, generatedComClassAttributeType)))
53+
{
54+
return;
55+
}
56+
57+
foreach (Diagnostic diagnostic in GetDiagnosticsForAnnotatedClass(classToAnalyze, unsafeCodeIsEnabled, generatedComInterfaceAttributeType))
58+
{
59+
context.ReportDiagnostic(diagnostic);
60+
}
61+
}
62+
63+
public static IEnumerable<Diagnostic> GetDiagnosticsForAnnotatedClass(INamedTypeSymbol annotatedClass, bool unsafeCodeIsEnabled, INamedTypeSymbol? generatedComInterfaceAttributeType)
64+
{
65+
Location location = annotatedClass.Locations.First();
66+
bool hasErrors = false;
67+
68+
if (!unsafeCodeIsEnabled)
69+
{
70+
yield return Diagnostic.Create(GeneratorDiagnostics.RequiresAllowUnsafeBlocks, location);
71+
hasErrors = true;
72+
}
73+
74+
var declarationNode = (TypeDeclarationSyntax)location.SourceTree.GetRoot().FindNode(location.SourceSpan);
75+
76+
if (!declarationNode.IsInPartialContext(out _))
77+
{
78+
yield return Diagnostic.Create(
79+
GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier,
80+
location,
81+
annotatedClass);
82+
hasErrors = true;
83+
}
84+
85+
if (hasErrors)
86+
{
87+
// If we already reported at least one error avoid stacking a warning on top of it
88+
yield break;
89+
}
90+
91+
foreach (INamedTypeSymbol iface in annotatedClass.AllInterfaces)
92+
{
93+
if (iface.GetAttributes().FirstOrDefault(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, generatedComInterfaceAttributeType)) is { } generatedComInterfaceAttribute &&
94+
GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComInterfaceAttribute).Options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper))
95+
{
96+
yield break;
97+
}
98+
}
99+
100+
// Class doesn't implement any generated COM interface. Report a warning about that
101+
yield return Diagnostic.Create(
102+
GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface,
103+
location,
104+
annotatedClass);
105+
}
106+
}

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
using Microsoft.CodeAnalysis;
99
using Microsoft.CodeAnalysis.CSharp;
1010
using Microsoft.CodeAnalysis.CSharp.Syntax;
11+
using Microsoft.CodeAnalysis.DotnetRuntime.Extensions;
12+
using Microsoft.Interop.Analyzers;
1113
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
1214
using static Microsoft.Interop.SyntaxFactoryExtensions;
1315

@@ -18,24 +20,28 @@ public class ComClassGenerator : IIncrementalGenerator
1820
{
1921
public void Initialize(IncrementalGeneratorInitializationContext context)
2022
{
21-
var unsafeCodeIsEnabled = context.CompilationProvider.Select((comp, ct) => comp.Options is CSharpCompilationOptions { AllowUnsafe: true }); // Unsafe code enabled
2223
// Get all types with the [GeneratedComClassAttribute] attribute.
23-
var attributedClassesOrDiagnostics = context.SyntaxProvider
24+
var attributedClasses = context.SyntaxProvider
2425
.ForAttributeWithMetadataName(
2526
TypeNames.GeneratedComClassAttribute,
2627
static (node, ct) => node is ClassDeclarationSyntax,
27-
static (context, ct) => context)
28-
.Combine(unsafeCodeIsEnabled)
29-
.Select(static (data, ct) =>
28+
static (context, _) =>
3029
{
31-
var context = data.Left;
32-
var unsafeCodeIsEnabled = data.Right;
3330
var type = (INamedTypeSymbol)context.TargetSymbol;
3431
var syntax = (ClassDeclarationSyntax)context.TargetNode;
35-
return ComClassInfo.From(type, syntax, unsafeCodeIsEnabled);
36-
});
32+
var compilation = context.SemanticModel.Compilation;
33+
var unsafeCodeIsEnabled = compilation.Options is CSharpCompilationOptions { AllowUnsafe: true };
34+
INamedTypeSymbol? generatedComInterfaceAttributeType = compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute);
3735

38-
var attributedClasses = context.FilterAndReportDiagnostics(attributedClassesOrDiagnostics);
36+
// Currently all reported diagnostics are fatal to the generator
37+
if (ComClassGeneratorDiagnosticsAnalyzer.GetDiagnosticsForAnnotatedClass(type, unsafeCodeIsEnabled, generatedComInterfaceAttributeType).Any())
38+
{
39+
return null;
40+
}
41+
42+
return ComClassInfo.From(type, syntax, generatedComInterfaceAttributeType);
43+
})
44+
.Where(static info => info is not null);
3945

4046
var classInfoType = attributedClasses
4147
.Select(static (info, ct) => new ItemAndSyntaxes<ComClassInfo>(info,

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassInfo.cs

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,26 +23,12 @@ private ComClassInfo(string className, ContainingSyntaxContext containingSyntaxC
2323
ImplementedInterfacesNames = implementedInterfacesNames;
2424
}
2525

26-
public static DiagnosticOr<ComClassInfo> From(INamedTypeSymbol type, ClassDeclarationSyntax syntax, bool unsafeCodeIsEnabled)
26+
public static ComClassInfo From(INamedTypeSymbol type, ClassDeclarationSyntax syntax, INamedTypeSymbol? generatedComInterfaceAttributeType)
2727
{
28-
if (!unsafeCodeIsEnabled)
29-
{
30-
return DiagnosticOr<ComClassInfo>.From(DiagnosticInfo.Create(GeneratorDiagnostics.RequiresAllowUnsafeBlocks, syntax.Identifier.GetLocation()));
31-
}
32-
33-
if (!syntax.IsInPartialContext(out _))
34-
{
35-
return DiagnosticOr<ComClassInfo>.From(
36-
DiagnosticInfo.Create(
37-
GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier,
38-
syntax.Identifier.GetLocation(),
39-
type.ToDisplayString()));
40-
}
41-
4228
ImmutableArray<string>.Builder names = ImmutableArray.CreateBuilder<string>();
4329
foreach (INamedTypeSymbol iface in type.AllInterfaces)
4430
{
45-
AttributeData? generatedComInterfaceAttribute = iface.GetAttributes().FirstOrDefault(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute);
31+
AttributeData? generatedComInterfaceAttribute = iface.GetAttributes().FirstOrDefault(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, generatedComInterfaceAttributeType));
4632
if (generatedComInterfaceAttribute is not null)
4733
{
4834
var attributeData = GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComInterfaceAttribute);
@@ -53,19 +39,11 @@ public static DiagnosticOr<ComClassInfo> From(INamedTypeSymbol type, ClassDeclar
5339
}
5440
}
5541

56-
if (names.Count == 0)
57-
{
58-
return DiagnosticOr<ComClassInfo>.From(DiagnosticInfo.Create(GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface,
59-
syntax.Identifier.GetLocation(),
60-
type.ToDisplayString()));
61-
}
62-
63-
return DiagnosticOr<ComClassInfo>.From(
64-
new ComClassInfo(
65-
type.ToDisplayString(),
66-
new ContainingSyntaxContext(syntax),
67-
new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
68-
new(names.ToImmutable())));
42+
return new ComClassInfo(
43+
type.ToDisplayString(),
44+
new ContainingSyntaxContext(syntax),
45+
new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
46+
new(names.ToImmutable()));
6947
}
7048

7149
public bool Equals(ComClassInfo? other)

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/DiagnosticOr.cs

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,8 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5-
using System.Collections.Generic;
65
using System.Collections.Immutable;
76
using System.Diagnostics;
8-
using System.Linq;
9-
using Microsoft.CodeAnalysis;
107

118
namespace Microsoft.Interop
129
{
@@ -88,27 +85,4 @@ public static DiagnosticOr<T> From(T value, params DiagnosticInfo[] diagnostics)
8885
return new ValueAndDiagnostic(value, ImmutableArray.Create(diagnostics));
8986
}
9087
}
91-
92-
public static class DiagnosticOrTHelperExtensions
93-
{
94-
/// <summary>
95-
/// Splits the elements of <paramref name="provider"/> into a values provider and a diagnostics provider.
96-
/// </summary>
97-
public static (IncrementalValuesProvider<T>, IncrementalValuesProvider<DiagnosticInfo>) Split<T>(this IncrementalValuesProvider<DiagnosticOr<T>> provider)
98-
{
99-
var values = provider.Where(x => x.HasValue).Select(static (x, ct) => x.Value);
100-
var diagnostics = provider.Where(x => x.HasDiagnostic).SelectMany(static (x, ct) => x.Diagnostics);
101-
return (values, diagnostics);
102-
}
103-
104-
/// <summary>
105-
/// Filters the <see cref="IncrementalValuesProvider{TValue}"/> by whether or not the is a <see cref="Diagnostic"/>, reports the diagnostics, and returns the values.
106-
/// </summary>
107-
public static IncrementalValuesProvider<T> FilterAndReportDiagnostics<T>(this IncrementalGeneratorInitializationContext ctx, IncrementalValuesProvider<DiagnosticOr<T>> diagnosticOrValues)
108-
{
109-
var (values, diagnostics) = diagnosticOrValues.Split();
110-
ctx.RegisterDiagnostics(diagnostics);
111-
return values;
112-
}
113-
}
11488
}

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalGeneratorInitializationContextExtensions.cs

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4-
using System;
5-
using System.Collections.Generic;
64
using System.Collections.Immutable;
75
using System.Linq;
8-
using System.Reflection;
96
using System.Text;
107
using Microsoft.CodeAnalysis;
11-
using Microsoft.CodeAnalysis.Diagnostics;
128

139
namespace Microsoft.Interop
1410
{
@@ -50,22 +46,6 @@ public static IncrementalValueProvider<StubEnvironment> CreateStubEnvironmentPro
5046
new StubEnvironment(data.Left, data.Right));
5147
}
5248

53-
public static void RegisterDiagnostics(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider<DiagnosticInfo> diagnostics)
54-
{
55-
context.RegisterSourceOutput(diagnostics.Where(diag => diag is not null), (context, diagnostic) =>
56-
{
57-
context.ReportDiagnostic(diagnostic.ToDiagnostic());
58-
});
59-
}
60-
61-
public static void RegisterDiagnostics(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider<Diagnostic> diagnostics)
62-
{
63-
context.RegisterSourceOutput(diagnostics.Where(diag => diag is not null), (context, diagnostic) =>
64-
{
65-
context.ReportDiagnostic(diagnostic);
66-
});
67-
}
68-
6949
public static void RegisterConcatenatedSyntaxOutputs<TNode>(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider<TNode> nodes, string fileName)
7050
where TNode : SyntaxNode
7151
{

0 commit comments

Comments
 (0)