Skip to content

Commit 8b58e59

Browse files
Mrwcontext fixes (#7998)
Fixes #7987 and #7990
1 parent 69e96b1 commit 8b58e59

File tree

16 files changed

+266
-18
lines changed

16 files changed

+266
-18
lines changed

packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ModelReaderWriterContextDefinition.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ namespace Microsoft.TypeSpec.Generator.ClientModel.Providers
1515
{
1616
internal class ModelReaderWriterContextDefinition : TypeProvider
1717
{
18-
internal static string s_name = $"{RemovePeriods(ScmCodeModelGenerator.Instance.TypeFactory.PrimaryNamespace)}Context";
18+
private static readonly string _name = $"{RemovePeriods(ScmCodeModelGenerator.Instance.TypeFactory.PrimaryNamespace)}Context";
1919

20-
protected override string BuildName() => s_name;
20+
protected override string BuildName() => _name;
2121

2222
protected override string BuildRelativeFilePath() => Path.Combine("src", "Generated", "Models", $"{Name}.cs");
2323

packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/ScmCodeModelGenerator.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ public class ScmCodeModelGenerator : CodeModelGenerator
2020
private ScmOutputLibrary? _scmOutputLibrary;
2121
public override OutputLibrary OutputLibrary => _scmOutputLibrary ??= new();
2222

23+
internal ModelReaderWriterContextDefinition MrwContextDefinition { get; } = new ModelReaderWriterContextDefinition();
24+
2325
public override ScmTypeFactory TypeFactory { get; }
2426

2527
[ImportingConstructor]
@@ -36,7 +38,8 @@ protected override void Configure()
3638
AddMetadataReference(MetadataReference.CreateFromFile(typeof(ClientResult).Assembly.Location));
3739
AddMetadataReference(MetadataReference.CreateFromFile(typeof(BinaryData).Assembly.Location));
3840
AddMetadataReference(MetadataReference.CreateFromFile(typeof(JsonSerializer).Assembly.Location));
39-
AddTypeToKeepPublic(ModelReaderWriterContextDefinition.s_name);
41+
AddTypeToKeepPublic(MrwContextDefinition);
42+
AddNonRootType(MrwContextDefinition);
4043
}
4144
}
4245
}

packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/ScmOutputLibrary.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ protected override TypeProvider[] BuildTypeProviders()
7777
new BinaryContentHelperDefinition(),
7878
new PipelineRequestHeadersExtensionsDefinition(),
7979
.. GetMultipartFormDataBinaryContentDefinition(),
80-
new ModelReaderWriterContextDefinition()
80+
ScmCodeModelGenerator.Instance.MrwContextDefinition
8181
];
8282
}
8383

packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ public async Task ExecuteAsync()
4343
await customCodeWorkspace.GetCompilationAsync(),
4444
await GeneratedCodeWorkspace.LoadBaselineContract());
4545

46+
// Configure must be called after the SourceInputModel is set.
47+
CodeModelGenerator.Instance.Configure();
48+
4649
GeneratedCodeWorkspace generatedCodeWorkspace = await GeneratedCodeWorkspace.Create();
4750

4851
var output = CodeModelGenerator.Instance.OutputLibrary;

packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CodeModelGenerator.cs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ public void AddSharedSourceDirectory(string sharedSourceDirectory)
128128
internal HashSet<string> TypesToKeep { get; } = [];
129129

130130
internal HashSet<string> TypesToKeepPublic { get; } = [];
131+
internal HashSet<string> NonRootTypes { get; } = [];
131132

132133
/// <summary>
133134
/// Adds a type to the list of types to keep.
@@ -147,7 +148,23 @@ public void AddTypeToKeep(string typeName)
147148
/// <summary>
148149
/// Adds a type to the list of types to keep as public.
149150
/// </summary>
150-
/// <param name="typeName">The type provider representing the type.</param>
151-
public void AddTypeToKeepPublic(string typeName) => TypesToKeepPublic.Add(typeName);
151+
/// <param name="type">The type provider representing the type.</param>
152+
public void AddTypeToKeepPublic(TypeProvider type) => TypesToKeepPublic.Add(type.Type.FullyQualifiedName);
153+
154+
/// <summary>
155+
/// Adds a type to the list of non-root type providers. Non root type providers are types whose
156+
/// references do not contribute to usages of the generated code. Therefore if the 'unreferenced-types-handling' property
157+
/// is not set to 'keepAll', any types referenced by non-root type providers will not automatically be kept.
158+
/// </summary>
159+
/// <param name="type">The fully qualified type name.</param>
160+
public void AddNonRootType(string type) => NonRootTypes.Add(type);
161+
162+
/// <summary>
163+
/// Adds a type to the list of non-root type providers. Non root type providers are types whose
164+
/// references do not contribute to usages of the generated code. Therefore if the 'unreferenced-types-handling' property
165+
/// is not set to 'keepAll', any types referenced by non-root type providers will not automatically be kept.
166+
/// </summary>
167+
/// <param name="type">The type provider representing the type</param>
168+
public void AddNonRootType(TypeProvider type) => NonRootTypes.Add(type.Type.FullyQualifiedName);
152169
}
153170
}

packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,12 @@ internal static Project AddDirectory(Project project, string directory, Func<str
251251
public async Task PostProcessAsync()
252252
{
253253
var modelFactory = CodeModelGenerator.Instance.OutputLibrary.ModelFactory.Value;
254+
var nonRootTypes = CodeModelGenerator.Instance.NonRootTypes;
254255
var postProcessor = new PostProcessor(
255256
[.. CodeModelGenerator.Instance.TypeFactory.UnionVariantTypesToKeep, .. CodeModelGenerator.Instance.TypesToKeep],
256-
modelFactoryFullName: $"{modelFactory.Type.Namespace}.{modelFactory.Name}");
257+
modelFactoryFullName: modelFactory.Type.FullyQualifiedName,
258+
additionalNonRootTypeFullNames: nonRootTypes);
259+
257260
switch (Configuration.UnreferencedTypesHandling)
258261
{
259262
case Configuration.UnreferencedTypesHandlingOption.KeepAll:

packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,18 @@ namespace Microsoft.TypeSpec.Generator
1616
internal class PostProcessor
1717
{
1818
private readonly string? _modelFactoryFullName;
19-
private readonly string? _aspExtensionClassName;
19+
private readonly IEnumerable<string>? _additionalNonRootTypeFullNames;
2020
private readonly HashSet<string> _typesToKeep;
2121
private INamedTypeSymbol? _modelFactorySymbol;
2222

2323
public PostProcessor(
2424
HashSet<string> typesToKeep,
2525
string? modelFactoryFullName = null,
26-
string? aspExtensionClassName = null)
26+
IEnumerable<string>? additionalNonRootTypeFullNames = null)
2727
{
2828
_typesToKeep = typesToKeep;
2929
_modelFactoryFullName = modelFactoryFullName;
30-
_aspExtensionClassName = aspExtensionClassName;
30+
_additionalNonRootTypeFullNames = additionalNonRootTypeFullNames;
3131
}
3232

3333
private record TypeSymbols(
@@ -54,10 +54,22 @@ private async Task<TypeSymbols> GetTypeSymbolsAsync(Compilation compilation,
5454
var documentCache = new Dictionary<Document, HashSet<INamedTypeSymbol>>();
5555

5656
if (_modelFactoryFullName != null)
57+
{
5758
_modelFactorySymbol = compilation.GetTypeByMetadataName(_modelFactoryFullName);
58-
INamedTypeSymbol? aspDotNetExtensionSymbol = null;
59-
if (_aspExtensionClassName != null)
60-
aspDotNetExtensionSymbol = compilation.GetTypeByMetadataName(_aspExtensionClassName);
59+
}
60+
61+
var additionalNonRootTypeSymbols = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);
62+
if (_additionalNonRootTypeFullNames != null)
63+
{
64+
foreach (var typeFullName in _additionalNonRootTypeFullNames)
65+
{
66+
var typeSymbol = compilation.GetTypeByMetadataName(typeFullName);
67+
if (typeSymbol != null)
68+
{
69+
additionalNonRootTypeSymbols.Add(typeSymbol);
70+
}
71+
}
72+
}
6173

6274
foreach (var document in project.Documents)
6375
{
@@ -83,8 +95,10 @@ private async Task<TypeSymbols> GetTypeSymbolsAsync(Compilation compilation,
8395

8496
// we do not add the model factory and aspDotNetExtension symbol to the declared symbol list so that it will never be included in any process of internalization or removal
8597
if (!SymbolEqualityComparer.Default.Equals(symbol, _modelFactorySymbol)
86-
&& !SymbolEqualityComparer.Default.Equals(symbol, aspDotNetExtensionSymbol))
98+
&& !additionalNonRootTypeSymbols.Contains(symbol))
99+
{
87100
result.Add(symbol);
101+
}
88102

89103
AddInList(declarationCache, symbol, typeDeclaration);
90104
AddInList(documentCache, document, symbol,
@@ -469,7 +483,39 @@ arg.Expression is TypeOfExpressionSyntax typeOfExpr &&
469483

470484
if (invalidAttributes.Count > 0)
471485
{
486+
// Check if any invalid attribute has type-level XML docs in its leading trivia
487+
var attributeWithDocs = invalidAttributes
488+
.OrderBy(a => a.SpanStart)
489+
.FirstOrDefault(attr => attr.GetLeadingTrivia().Any(t =>
490+
t.IsKind(SyntaxKind.SingleLineDocumentationCommentTrivia) ||
491+
t.IsKind(SyntaxKind.MultiLineDocumentationCommentTrivia)));
492+
493+
SyntaxTriviaList? xmlDocs = null;
494+
if (attributeWithDocs != null)
495+
{
496+
xmlDocs = attributeWithDocs.GetLeadingTrivia()
497+
.Where(t => t.IsKind(SyntaxKind.SingleLineDocumentationCommentTrivia) ||
498+
t.IsKind(SyntaxKind.MultiLineDocumentationCommentTrivia))
499+
.ToSyntaxTriviaList();
500+
}
501+
502+
// Remove all invalid attributes without keeping trivia
472503
cu = cu.RemoveNodes(invalidAttributes, SyntaxRemoveOptions.KeepNoTrivia)!;
504+
505+
// If we found XML docs, reattach them to the type declaration
506+
if (xmlDocs?.Any() == true)
507+
{
508+
var typeDecl = cu.DescendantNodes()
509+
.OfType<TypeDeclarationSyntax>()
510+
.FirstOrDefault();
511+
512+
if (typeDecl != null)
513+
{
514+
cu = cu.ReplaceNode(typeDecl,
515+
typeDecl.WithLeadingTrivia(xmlDocs.Value.AddRange(typeDecl.GetLeadingTrivia())));
516+
}
517+
}
518+
473519
solution = solution.WithDocumentSyntaxRoot(documentId, cu);
474520
}
475521

packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/StartUp/GeneratorHandler.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@ internal void SelectGenerator(CommandLineOptions options)
156156
}
157157
}
158158

159-
CodeModelGenerator.Instance.Configure();
160159
loaded = true;
161160
break;
162161
}

packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4+
using System.ClientModel.Primitives;
5+
using System.Collections.Generic;
46
using System.IO;
57
using System.Linq;
68
using System.Threading.Tasks;
@@ -58,6 +60,7 @@ public async Task RemovesInvalidUsings()
5860
CollectionAssert.Contains(usings, "System");
5961
}
6062

63+
6164
[Test]
6265
public async Task DoesNotRemoveValidUsings()
6366
{
@@ -102,11 +105,129 @@ public async Task DoesNotRemoveValidUsings()
102105
CollectionAssert.Contains(usings, "System");
103106
}
104107

108+
[Test]
109+
public async Task RemovesInvalidAttributes()
110+
{
111+
MockHelpers.LoadMockGenerator();
112+
var workspace = new AdhocWorkspace();
113+
var projectInfo = ProjectInfo.Create(
114+
ProjectId.CreateNewId(),
115+
VersionStamp.Create(),
116+
name: "TestProj",
117+
assemblyName: "TestProj",
118+
language: LanguageNames.CSharp)
119+
.WithMetadataReferences(new[]
120+
{
121+
MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
122+
MetadataReference.CreateFromFile(typeof(ModelReaderWriterBuildableAttribute).Assembly.Location)
123+
});
124+
125+
var project = workspace.AddProject(projectInfo);
126+
var folder = Helpers.GetAssetFileOrDirectoryPath(false);
127+
const string removesInvalidAttributesFileName = "RemovesInvalidAttributes.cs";
128+
project = project.AddDocument(
129+
removesInvalidAttributesFileName,
130+
File.ReadAllText(Path.Join(folder, removesInvalidAttributesFileName))).Project;
131+
project = project.AddDocument(
132+
"Model.cs",
133+
File.ReadAllText(Path.Join(folder, "Model.cs"))).Project;
134+
project = project.AddDocument(
135+
"RootClass.cs",
136+
File.ReadAllText(Path.Join(folder, "RootClass.cs"))).Project;
137+
var postProcessor = new TestPostProcessor("RootClass.cs", nonRootTypes: ["Sample.KeepMe"]);
138+
139+
var resultProject = await postProcessor.RemoveAsync(project);
140+
var doc= resultProject.Documents
141+
.Single(d => d.Name == removesInvalidAttributesFileName);
142+
var root = await doc.GetSyntaxRootAsync();
143+
var compilation = (CompilationUnitSyntax)root!;
144+
145+
var type = compilation
146+
.DescendantNodes()
147+
.OfType<BaseTypeDeclarationSyntax>()
148+
.SingleOrDefault(t => t.Identifier.Text == "KeepMe");
149+
Assert.IsNotNull(type, "The class 'KeepMe' should still exist.");
150+
151+
var attributes = type!.AttributeLists.SelectMany(a => a.Attributes).Select(al => al.ToString()).ToList();
152+
153+
// The invalid attribute should be removed
154+
CollectionAssert.DoesNotContain(attributes, "ModelReaderWriterBuildable");
155+
156+
// The class documentation should be preserved
157+
var classDocumentation = compilation.DescendantNodes()
158+
.OfType<ClassDeclarationSyntax>()
159+
.Single(c => c.Identifier.Text == "KeepMe")
160+
.GetLeadingTrivia()
161+
.Select(t => t.ToString())
162+
.SingleOrDefault(t => t.Trim().Contains("<summary>"));
163+
164+
StringAssert.Contains("Class docs that should be kept.", classDocumentation);
165+
}
166+
167+
[Test]
168+
public async Task DoesNotRemoveValidAttributes()
169+
{
170+
MockHelpers.LoadMockGenerator();
171+
var workspace = new AdhocWorkspace();
172+
var projectInfo = ProjectInfo.Create(
173+
ProjectId.CreateNewId(),
174+
VersionStamp.Create(),
175+
name: "TestProj",
176+
assemblyName: "TestProj",
177+
language: LanguageNames.CSharp)
178+
.WithMetadataReferences(new[]
179+
{
180+
MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
181+
MetadataReference.CreateFromFile(typeof(ModelReaderWriterBuildableAttribute).Assembly.Location)
182+
});
183+
184+
var project = workspace.AddProject(projectInfo);
185+
var folder = Helpers.GetAssetFileOrDirectoryPath(false);
186+
const string doesNotRemoveValidAttributesFileName = "DoesNotRemoveValidAttributes.cs";
187+
project = project.AddDocument(
188+
doesNotRemoveValidAttributesFileName,
189+
File.ReadAllText(Path.Join(folder, doesNotRemoveValidAttributesFileName))).Project;
190+
project = project.AddDocument(
191+
"Model.cs",
192+
File.ReadAllText(Path.Join(folder, "Model.cs"))).Project;
193+
project = project.AddDocument(
194+
"RootClass.cs",
195+
File.ReadAllText(Path.Join(folder, "RootClass.cs"))).Project;
196+
var postProcessor = new TestPostProcessor("RootClass.cs");
197+
198+
var resultProject = await postProcessor.RemoveAsync(project);
199+
var doc= resultProject.Documents
200+
.Single(d => d.Name == doesNotRemoveValidAttributesFileName);
201+
var root = await doc.GetSyntaxRootAsync();
202+
var compilation = (CompilationUnitSyntax)root!;
203+
204+
var type = compilation
205+
.DescendantNodes()
206+
.OfType<BaseTypeDeclarationSyntax>()
207+
.SingleOrDefault(t => t.Identifier.Text == "KeepMe");
208+
Assert.IsNotNull(type, "The class 'KeepMe' should still exist.");
209+
210+
var attributes = type!.AttributeLists.SelectMany(a => a.Attributes).Select(al => al.ToString().Trim()).ToList();
211+
212+
// The valid attribute should be retained
213+
CollectionAssert.Contains(attributes, "ModelReaderWriterBuildable(typeof(Model))");
214+
215+
// The class documentation should be preserved
216+
var classDocumentation = compilation.DescendantNodes()
217+
.OfType<ClassDeclarationSyntax>()
218+
.Single(c => c.Identifier.Text == "KeepMe")
219+
.GetLeadingTrivia()
220+
.Select(t => t.ToString())
221+
.SingleOrDefault(t => t.Trim().Contains("<summary>"));
222+
223+
StringAssert.Contains("Class docs that should be kept.", classDocumentation);
224+
}
225+
105226
private class TestPostProcessor : PostProcessor
106227
{
107228
private readonly string _rootFile;
108229

109-
public TestPostProcessor(string rootFile) : base([])
230+
public TestPostProcessor(string rootFile, IEnumerable<string>? nonRootTypes = null) : base([], additionalNonRootTypeFullNames: nonRootTypes)
110231
{
111232
_rootFile = rootFile;
112233
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System;
2+
using System.ClientModel.Primitives;
3+
using Sample.Models;
4+
5+
namespace Sample
6+
{
7+
/// <summary>
8+
/// Class docs that should be kept.
9+
/// </summary>
10+
[ModelReaderWriterBuildable(typeof(Model))]
11+
public class KeepMe
12+
{
13+
public void Foo() => Console.WriteLine(""hello"");
14+
}
15+
}

0 commit comments

Comments
 (0)