diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs new file mode 100644 index 00000000000..f96a0d5053b --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text.RegularExpressions; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +using Microsoft.CodeAnalysis; +using Microsoft.TypeSpec.Generator.Primitives; + +namespace Microsoft.TypeSpec.Generator.Perf +{ + public class PostProcessingBenchmark + { + private const string GeneratedDirectoryEnvironmentVariable = "POSTPROCESSING_BENCHMARK_GENERATED_DIR"; + private static readonly Regex NamespaceDeclarationRegex = new( + @"\bnamespace\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*)", + RegexOptions.Compiled); + + [Params(1, 5)] + public int CorpusMultiplier { get; set; } + + private (string Name, string Content)[] _generatedFiles = []; + + [GlobalSetup] + public void GlobalSetup() + { + InitializeGenerator(); + + var generatedDirectory = FindGeneratedDirectory(); + var sourceFiles = Directory.GetFiles(generatedDirectory, "*.cs", SearchOption.AllDirectories) + .OrderBy(static path => path, StringComparer.Ordinal) + .ToArray(); + + if (sourceFiles.Length == 0) + { + throw new InvalidOperationException($"No generated C# files found under '{generatedDirectory}'."); + } + + var declaredNamespaces = GetDeclaredNamespaces(sourceFiles); + _generatedFiles = BuildCorpus(generatedDirectory, sourceFiles, declaredNamespaces); + } + + [Benchmark] + public async Task ProcessSampleTypeSpecGeneratedFiles() + { + GeneratedCodeWorkspace.Initialize(); + var workspace = await GeneratedCodeWorkspace.Create(isCustomCodeProject: false); + + foreach (var file in _generatedFiles) + { + await workspace.AddGeneratedFile(new CodeFile(file.Content, file.Name)); + } + + var totalLength = 0; + await foreach (var file in workspace.GetGeneratedFilesAsync()) + { + totalLength += file.Text.Length; + } + + return totalLength; + } + + private (string Name, string Content)[] BuildCorpus(string generatedDirectory, string[] sourceFiles, IReadOnlyList declaredNamespaces) + { + var generatedFiles = new List<(string Name, string Content)>(sourceFiles.Length * CorpusMultiplier); + for (var i = 0; i < CorpusMultiplier; i++) + { + var namespaceSuffix = CorpusMultiplier == 1 ? string.Empty : $".BenchmarkCopy{i}"; + var folderPrefix = CorpusMultiplier == 1 ? string.Empty : $"BenchmarkCopy{i}"; + foreach (var path in sourceFiles) + { + var relativePath = Path.GetRelativePath(generatedDirectory, path); + var content = File.ReadAllText(path); + if (CorpusMultiplier > 1) + { + content = MakeNamespacesUnique(content, declaredNamespaces, namespaceSuffix); + } + + generatedFiles.Add((Path.Combine(folderPrefix, relativePath), content)); + } + } + + return generatedFiles.ToArray(); + } + + private static IReadOnlyList GetDeclaredNamespaces(string[] sourceFiles) + { + var declaredNamespaces = sourceFiles + .SelectMany(static path => NamespaceDeclarationRegex.Matches(File.ReadAllText(path))) + .Select(static match => match.Groups[1].Value) + .Distinct(StringComparer.Ordinal) + .ToArray(); + + return declaredNamespaces + .Where(ns => !declaredNamespaces.Any(candidate => + !string.Equals(ns, candidate, StringComparison.Ordinal) && + ns.StartsWith(candidate + ".", StringComparison.Ordinal))) + .OrderByDescending(static ns => ns.Length) + .ToArray(); + } + + private static string MakeNamespacesUnique(string content, IReadOnlyList declaredNamespaces, string namespaceSuffix) + { + foreach (var declaredNamespace in declaredNamespaces) + { + var escapedNamespace = Regex.Escape(declaredNamespace); + content = content.Replace($"global::{declaredNamespace}.", $"global::{declaredNamespace}{namespaceSuffix}.", StringComparison.Ordinal); + content = Regex.Replace( + content, + $@"(? GetMetadataReferencePaths() + { + HashSet referencePaths = new(StringComparer.OrdinalIgnoreCase); + + if (AppContext.GetData("TRUSTED_PLATFORM_ASSEMBLIES") is string trustedPlatformAssemblies) + { + foreach (var referencePath in trustedPlatformAssemblies.Split(Path.PathSeparator)) + { + if (referencePaths.Add(referencePath)) + { + yield return referencePath; + } + } + } + + foreach (var referencePath in Directory.GetFiles(AppContext.BaseDirectory, "*.dll", SearchOption.TopDirectoryOnly)) + { + if (referencePaths.Add(referencePath)) + { + yield return referencePath; + } + } + } + + private sealed class BenchmarkCodeModelGenerator : CodeModelGenerator + { + public BenchmarkCodeModelGenerator(string outputPath) + : base(new GeneratorContext(Configuration.Load(outputPath, "{\"package-name\":\"Sample.TypeSpec\",\"disable-xml-docs\":false}"))) + { + } + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index d90488b4c85..38ac362d741 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -6,14 +6,17 @@ using System.Diagnostics; using System.IO; using System.Linq; +using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; using Microsoft.Build.Construction; using Microsoft.CodeAnalysis; using MSBuildProjectCollection = Microsoft.Build.Evaluation.ProjectCollection; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Formatting; using Microsoft.CodeAnalysis.Simplification; +using Microsoft.CodeAnalysis.Text; using Microsoft.TypeSpec.Generator.Primitives; using Microsoft.TypeSpec.Generator.Providers; using Microsoft.TypeSpec.Generator.Utilities; @@ -148,7 +151,20 @@ private async Task ProcessDocument(Document document, MemberRemoverRew } document = document.WithSyntaxRoot(root); - document = await Simplifier.ReduceAsync(document); + for (int i = 0; i < 8; i++) + { + var reducedDocument = await ReduceQualifiedNamesAsync(document); + if (ReferenceEquals(reducedDocument, document)) + { + break; + } + + document = reducedDocument; + } + + document = await ReduceSemanticOnlyAsync(document); + document = await ReduceSyntaxOnlyAsync(document); + document = await ReduceDocumentationQualifiedNamesAsync(document); // Reformat if any custom rewriters have been applied if (CodeModelGenerator.Instance.Rewriters.Count > 0) @@ -158,6 +174,1031 @@ private async Task ProcessDocument(Document document, MemberRemoverRew return document; } + private static async Task ReduceQualifiedNamesAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + var semanticModel = await document.GetSemanticModelAsync(); + if (root == null || semanticModel == null) + { + return document; + } + + var safeNameReplacements = new Dictionary(); + foreach (var name in root.DescendantNodes().OfType()) + { + if (name is not QualifiedNameSyntax and not AliasQualifiedNameSyntax || + name.Parent is QualifiedNameSyntax || + IsInUnsupportedQualifiedNameContext(name)) + { + continue; + } + + var originalSymbol = GetSymbol(semanticModel, name); + if (originalSymbol == null) + { + continue; + } + + if (TryGetNameReplacement(semanticModel, name, originalSymbol, out var replacement)) + { + safeNameReplacements.Add(name, replacement); + } + } + + foreach (var attribute in root.DescendantNodes().OfType()) + { + if (TryGetAttributeNameReplacement(semanticModel, attribute, out var replacement)) + { + safeNameReplacements[attribute.Name] = replacement; + } + } + + var safeMemberAccessReplacements = new Dictionary(); + foreach (var memberAccess in root.DescendantNodes().OfType()) + { + if (IsInUnsupportedQualifiedNameContext(memberAccess)) + { + continue; + } + + if (TryGetMemberAccessReplacement(semanticModel, memberAccess, out var replacement)) + { + safeMemberAccessReplacements.Add(memberAccess, replacement); + } + } + + if (safeNameReplacements.Count == 0 && safeMemberAccessReplacements.Count == 0) + { + return document; + } + + var rewrittenRoot = root.ReplaceNodes( + safeNameReplacements.Keys + .Concat(safeMemberAccessReplacements.Keys), + (original, rewritten) => original switch + { + NameSyntax name => safeNameReplacements[name].WithTriviaFrom(rewritten), + MemberAccessExpressionSyntax memberAccess => safeMemberAccessReplacements[memberAccess].WithTriviaFrom(rewritten), + _ => rewritten + }); + return document.WithSyntaxRoot(rewrittenRoot); + } + + private static async Task ReduceSemanticOnlyAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + var semanticModel = await document.GetSemanticModelAsync(); + if (root == null || semanticModel == null) + { + return document; + } + + var memberAccessReplacements = new Dictionary(); + foreach (var memberAccess in root.DescendantNodes().OfType()) + { + if (!IsInUnsupportedQualifiedNameContext(memberAccess) && + TryGetThisQualificationReplacement(semanticModel, memberAccess, out var replacement)) + { + memberAccessReplacements.Add(memberAccess, replacement); + } + } + + var genericNameReplacements = new Dictionary(); + foreach (var genericName in root.DescendantNodes().OfType()) + { + if (TryGetGenericMethodTypeArgumentsReplacement(semanticModel, genericName, out var replacement)) + { + genericNameReplacements.Add(genericName, replacement); + } + } + + if (memberAccessReplacements.Count == 0 && genericNameReplacements.Count == 0) + { + return document; + } + + var rewrittenRoot = root.ReplaceNodes( + memberAccessReplacements.Keys.Concat(genericNameReplacements.Keys), + (original, rewritten) => original switch + { + GenericNameSyntax genericName => genericNameReplacements[genericName].WithTriviaFrom(rewritten), + MemberAccessExpressionSyntax memberAccess => memberAccessReplacements[memberAccess].WithTriviaFrom(rewritten), + _ => rewritten + }); + return document.WithSyntaxRoot(rewrittenRoot); + } + + private static async Task ReduceSyntaxOnlyAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + if (root == null) + { + return document; + } + + var rewriter = new SyntaxOnlyReducer(); + var rewrittenRoot = rewriter.Visit(root); + return rewriter.Changed && rewrittenRoot != null + ? document.WithSyntaxRoot(rewrittenRoot) + : document; + } + + private sealed class SyntaxOnlyReducer : CSharpSyntaxRewriter + { + public bool Changed { get; private set; } + + public override SyntaxNode? VisitParenthesizedExpression(ParenthesizedExpressionSyntax node) + { + var rewritten = (ParenthesizedExpressionSyntax)base.VisitParenthesizedExpression(node)!; + if (CanRemoveParentheses(node)) + { + Changed = true; + return rewritten.Expression.WithTriviaFrom(rewritten); + } + + return rewritten; + } + + public override SyntaxNode? VisitParenthesizedPattern(ParenthesizedPatternSyntax node) + { + var rewritten = (ParenthesizedPatternSyntax)base.VisitParenthesizedPattern(node)!; + Changed = true; + return rewritten.Pattern.WithTriviaFrom(rewritten); + } + + public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node) + { + var rewritten = (IdentifierNameSyntax)base.VisitIdentifierName(node)!; + if (rewritten.Identifier.ValueText is not ("Byte" or "Char" or "String") || + node.Parent is MemberAccessExpressionSyntax or QualifiedNameSyntax) + { + return rewritten; + } + + Changed = true; + return GetPredefinedType(rewritten.Identifier.ValueText).WithTriviaFrom(rewritten); + } + + public override SyntaxNode? VisitQualifiedName(QualifiedNameSyntax node) + { + var rewritten = (QualifiedNameSyntax)base.VisitQualifiedName(node)!; + if (rewritten.Left is not IdentifierNameSyntax { Identifier.ValueText: "System" } || + rewritten.Right.Identifier.ValueText is not ("Byte" or "Char" or "String") || + node.Parent is MemberAccessExpressionSyntax or QualifiedNameSyntax) + { + return rewritten; + } + + Changed = true; + return GetPredefinedType(rewritten.Right.Identifier.ValueText).WithTriviaFrom(rewritten); + } + + public override SyntaxNode? VisitCastExpression(CastExpressionSyntax node) + { + var rewritten = (CastExpressionSyntax)base.VisitCastExpression(node)!; + if (rewritten.Expression is LiteralExpressionSyntax literalExpression && + (literalExpression.IsKind(SyntaxKind.NullLiteralExpression) || + literalExpression.IsKind(SyntaxKind.DefaultLiteralExpression)) && + (rewritten.Parent is EqualsValueClauseSyntax || node.Parent is EqualsValueClauseSyntax)) + { + Changed = true; + return rewritten.Expression.WithTriviaFrom(rewritten); + } + + return rewritten; + } + + public override SyntaxNode? VisitEqualsValueClause(EqualsValueClauseSyntax node) + { + var rewritten = (EqualsValueClauseSyntax)base.VisitEqualsValueClause(node)!; + if (rewritten.Value is CastExpressionSyntax { Expression: LiteralExpressionSyntax literalExpression } castExpression && + (literalExpression.IsKind(SyntaxKind.NullLiteralExpression) || + literalExpression.IsKind(SyntaxKind.DefaultLiteralExpression))) + { + Changed = true; + return rewritten.WithValue(castExpression.Expression.WithTriviaFrom(castExpression)); + } + + return rewritten; + } + + private static PredefinedTypeSyntax GetPredefinedType(string typeName) => typeName switch + { + "Byte" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.ByteKeyword)), + "Char" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.CharKeyword)), + "String" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.StringKeyword)), + _ => throw new InvalidOperationException($"Unexpected predefined type name: {typeName}") + }; + } + + private static ISymbol? GetSymbol(SemanticModel semanticModel, NameSyntax name) => + semanticModel.GetSymbolInfo(name).Symbol ?? + semanticModel.GetTypeInfo(name).Type; + + private static async Task ReduceParenthesesAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + if (root == null) + { + return document; + } + + var expressions = root.DescendantNodes() + .OfType() + .Where(CanRemoveParentheses) + .ToList(); + var patterns = root.DescendantNodes() + .OfType() + .ToList(); + if (expressions.Count == 0 && patterns.Count == 0) + { + return document; + } + + var rewrittenRoot = root + .ReplaceNodes( + expressions, + static (_, rewritten) => rewritten.Expression.WithTriviaFrom(rewritten)) + .ReplaceNodes( + patterns, + static (_, rewritten) => rewritten.Pattern.WithTriviaFrom(rewritten)); + return document.WithSyntaxRoot(rewrittenRoot); + } + + private static bool CanRemoveParentheses(ParenthesizedExpressionSyntax node) => node.Parent switch + { + ParenthesizedExpressionSyntax => true, + IfStatementSyntax ifStatement when ifStatement.Condition == node => true, + WhileStatementSyntax whileStatement when whileStatement.Condition == node => true, + DoStatementSyntax doStatement when doStatement.Condition == node => true, + ForStatementSyntax forStatement when forStatement.Condition == node => true, + SwitchStatementSyntax switchStatement when switchStatement.Expression == node => true, + ReturnStatementSyntax => true, + ArrowExpressionClauseSyntax => true, + EqualsValueClauseSyntax => true, + AssignmentExpressionSyntax assignment when assignment.Right == node => true, + ExpressionStatementSyntax when node.Expression is AssignmentExpressionSyntax => true, + ArgumentSyntax => node.Expression is not AssignmentExpressionSyntax and not ConditionalExpressionSyntax, + BracketedArgumentListSyntax => true, + ConditionalExpressionSyntax parent when parent.Condition == node => true, + WhenClauseSyntax whenClause when whenClause.Condition == node => true, + SwitchExpressionArmSyntax switchArm when switchArm.WhenClause?.Condition == node => true, + BinaryExpressionSyntax parent when parent.Left == node => + GetExpressionPrecedence(node.Expression) >= GetExpressionPrecedence(parent), + BinaryExpressionSyntax parent when parent.Right == node => + GetExpressionPrecedence(node.Expression) > GetExpressionPrecedence(parent) || + parent.IsKind(SyntaxKind.CoalesceExpression) && node.Expression.IsKind(SyntaxKind.CoalesceExpression), + PrefixUnaryExpressionSyntax => node.Expression is CastExpressionSyntax, + _ => false + }; + + private static int GetExpressionPrecedence(ExpressionSyntax expression) => expression.Kind() switch + { + SyntaxKind.SimpleMemberAccessExpression or + SyntaxKind.ElementAccessExpression or + SyntaxKind.InvocationExpression => 15, + SyntaxKind.CastExpression => 14, + SyntaxKind.UnaryMinusExpression or + SyntaxKind.UnaryPlusExpression or + SyntaxKind.LogicalNotExpression or + SyntaxKind.BitwiseNotExpression => 13, + SyntaxKind.MultiplyExpression or + SyntaxKind.DivideExpression or + SyntaxKind.ModuloExpression => 12, + SyntaxKind.AddExpression or + SyntaxKind.SubtractExpression => 11, + SyntaxKind.LeftShiftExpression or + SyntaxKind.RightShiftExpression => 10, + SyntaxKind.LessThanExpression or + SyntaxKind.LessThanOrEqualExpression or + SyntaxKind.GreaterThanExpression or + SyntaxKind.GreaterThanOrEqualExpression or + SyntaxKind.IsExpression or + SyntaxKind.AsExpression => 9, + SyntaxKind.EqualsExpression or + SyntaxKind.NotEqualsExpression => 8, + SyntaxKind.BitwiseAndExpression => 7, + SyntaxKind.ExclusiveOrExpression => 6, + SyntaxKind.BitwiseOrExpression => 5, + SyntaxKind.LogicalAndExpression => 4, + SyntaxKind.LogicalOrExpression => 3, + SyntaxKind.CoalesceExpression => 2, + SyntaxKind.SimpleAssignmentExpression or + SyntaxKind.AddAssignmentExpression or + SyntaxKind.SubtractAssignmentExpression or + SyntaxKind.MultiplyAssignmentExpression or + SyntaxKind.DivideAssignmentExpression or + SyntaxKind.ModuloAssignmentExpression or + SyntaxKind.AndAssignmentExpression or + SyntaxKind.ExclusiveOrAssignmentExpression or + SyntaxKind.OrAssignmentExpression or + SyntaxKind.LeftShiftAssignmentExpression or + SyntaxKind.RightShiftAssignmentExpression or + SyntaxKind.CoalesceAssignmentExpression => 1, + _ => 16 + }; + + private static async Task ReduceGenericMethodTypeArgumentsAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + var semanticModel = await document.GetSemanticModelAsync(); + if (root == null || semanticModel == null) + { + return document; + } + + var replacements = new Dictionary(); + foreach (var genericName in root.DescendantNodes().OfType()) + { + if (genericName.Parent is not MemberAccessExpressionSyntax memberAccess || + memberAccess.Name != genericName || + memberAccess.Parent is not InvocationExpressionSyntax invocation || + invocation.Expression != memberAccess) + { + continue; + } + + var originalSymbol = semanticModel.GetSymbolInfo(invocation).Symbol; + if (originalSymbol == null) + { + continue; + } + + var candidateName = SyntaxFactory.IdentifierName(genericName.Identifier).WithTriviaFrom(genericName); + var candidateInvocation = invocation.WithExpression(memberAccess.WithName(candidateName)); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + invocation.SpanStart, + candidateInvocation, + SpeculativeBindingOption.BindAsExpression).Symbol; + if (speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) + { + replacements.Add(genericName, candidateName); + } + } + + if (replacements.Count == 0) + { + return document; + } + + var rewrittenRoot = root.ReplaceNodes( + replacements.Keys, + (original, rewritten) => replacements[original].WithTriviaFrom(rewritten)); + return document.WithSyntaxRoot(rewrittenRoot); + } + + private static bool TryGetGenericMethodTypeArgumentsReplacement( + SemanticModel semanticModel, + GenericNameSyntax genericName, + out IdentifierNameSyntax replacement) + { + replacement = SyntaxFactory.IdentifierName(genericName.Identifier).WithTriviaFrom(genericName); + if (genericName.Parent is not MemberAccessExpressionSyntax memberAccess || + memberAccess.Name != genericName || + memberAccess.Parent is not InvocationExpressionSyntax invocation || + invocation.Expression != memberAccess) + { + return false; + } + + var originalSymbol = semanticModel.GetSymbolInfo(invocation).Symbol; + if (originalSymbol == null) + { + return false; + } + + var candidateInvocation = invocation.WithExpression(memberAccess.WithName(replacement)); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + invocation.SpanStart, + candidateInvocation, + SpeculativeBindingOption.BindAsExpression).Symbol; + return speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol); + } + + private static async Task ReduceThisQualificationAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + var semanticModel = await document.GetSemanticModelAsync(); + if (root == null || semanticModel == null) + { + return document; + } + + var replacements = new Dictionary(); + foreach (var memberAccess in root.DescendantNodes().OfType()) + { + if (memberAccess.Expression is not ThisExpressionSyntax || + IsInUnsupportedQualifiedNameContext(memberAccess)) + { + continue; + } + + var originalSymbol = semanticModel.GetSymbolInfo(memberAccess).Symbol; + var originalInvocationSymbol = memberAccess.Parent is InvocationExpressionSyntax invocation && invocation.Expression == memberAccess + ? semanticModel.GetSymbolInfo(invocation).Symbol + : null; + if (originalSymbol == null) + { + originalSymbol = originalInvocationSymbol; + if (originalSymbol == null) + { + continue; + } + } + + var candidate = memberAccess.Name.WithTriviaFrom(memberAccess); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + memberAccess.SpanStart, + candidate, + SpeculativeBindingOption.BindAsExpression).Symbol; + if (speculativeSymbol == null && + memberAccess.Parent is InvocationExpressionSyntax parentInvocation && + parentInvocation.Expression == memberAccess && + originalInvocationSymbol != null) + { + var candidateInvocation = parentInvocation.WithExpression(candidate); + speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + parentInvocation.SpanStart, + candidateInvocation, + SpeculativeBindingOption.BindAsExpression).Symbol; + originalSymbol = originalInvocationSymbol; + } + + if (speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) + { + replacements.Add(memberAccess, candidate); + } + } + + if (replacements.Count == 0) + { + return document; + } + + var rewrittenRoot = root.ReplaceNodes( + replacements.Keys, + (original, rewritten) => replacements[original].WithTriviaFrom(rewritten)); + return document.WithSyntaxRoot(rewrittenRoot); + } + + private static bool TryGetThisQualificationReplacement( + SemanticModel semanticModel, + MemberAccessExpressionSyntax memberAccess, + out ExpressionSyntax replacement) + { + replacement = memberAccess; + if (memberAccess.Expression is not ThisExpressionSyntax) + { + return false; + } + + var originalSymbol = semanticModel.GetSymbolInfo(memberAccess).Symbol; + var originalInvocationSymbol = memberAccess.Parent is InvocationExpressionSyntax invocation && invocation.Expression == memberAccess + ? semanticModel.GetSymbolInfo(invocation).Symbol + : null; + if (originalSymbol == null) + { + originalSymbol = originalInvocationSymbol; + if (originalSymbol == null) + { + return false; + } + } + + var candidate = memberAccess.Name.WithTriviaFrom(memberAccess); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + memberAccess.SpanStart, + candidate, + SpeculativeBindingOption.BindAsExpression).Symbol; + if (speculativeSymbol == null && + memberAccess.Parent is InvocationExpressionSyntax parentInvocation && + parentInvocation.Expression == memberAccess && + originalInvocationSymbol != null) + { + var candidateInvocation = parentInvocation.WithExpression(candidate); + speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + parentInvocation.SpanStart, + candidateInvocation, + SpeculativeBindingOption.BindAsExpression).Symbol; + originalSymbol = originalInvocationSymbol; + } + + if (speculativeSymbol == null || + !SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) + { + return false; + } + + replacement = candidate; + return true; + } + + private static async Task ReducePredefinedTypeNamesAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + if (root == null) + { + return document; + } + + var identifiers = root.DescendantNodes() + .OfType() + .Where(static identifier => + identifier.Identifier.ValueText is "Byte" or "Char" && + identifier.Parent is not MemberAccessExpressionSyntax and not QualifiedNameSyntax) + .ToList(); + var castExpressions = root.DescendantNodes() + .OfType() + .Where(static castExpression => + castExpression.Expression is LiteralExpressionSyntax literalExpression && + (literalExpression.IsKind(SyntaxKind.NullLiteralExpression) || + literalExpression.IsKind(SyntaxKind.DefaultLiteralExpression)) && + castExpression.Parent is EqualsValueClauseSyntax) + .ToList(); + + if (identifiers.Count == 0 && castExpressions.Count == 0) + { + return document; + } + + var rewrittenRoot = root + .ReplaceNodes( + identifiers, + static (_, rewritten) => rewritten.Identifier.ValueText switch + { + "Byte" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.ByteKeyword)).WithTriviaFrom(rewritten), + "Char" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.CharKeyword)).WithTriviaFrom(rewritten), + _ => rewritten + }) + .ReplaceNodes( + castExpressions, + static (_, rewritten) => rewritten.Expression.WithTriviaFrom(rewritten)); + return document.WithSyntaxRoot(rewrittenRoot); + } + + private static async Task ReduceDocumentationQualifiedNamesAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + if (root == null) + { + return document; + } + + var documentationTrivia = root.DescendantTrivia(descendIntoTrivia: true) + .Where(static trivia => + trivia.HasStructure && + trivia.IsKind(SyntaxKind.SingleLineDocumentationCommentTrivia) && + (trivia.ToFullString().Contains("global::", StringComparison.Ordinal) || + trivia.ToFullString().Contains("cref=\"", StringComparison.Ordinal))) + .ToList(); + if (documentationTrivia.Count == 0) + { + return document; + } + + var rewrittenRoot = root.ReplaceTrivia( + documentationTrivia, + (original, rewritten) => + { + var reduced = ReduceDocumentationTriviaText(root, original, rewritten.ToFullString()); + var parsedTrivia = SyntaxFactory.ParseLeadingTrivia(reduced); + return parsedTrivia.Count == 1 ? parsedTrivia[0] : rewritten; + }); + return document.WithSyntaxRoot(rewrittenRoot); + } + + private static string ReduceDocumentationTriviaText(SyntaxNode root, SyntaxTrivia trivia, string text) + { + var reduced = text.Replace("global::", string.Empty, StringComparison.Ordinal); + var namespacePrefix = trivia.Token.Parent? + .AncestorsAndSelf() + .OfType() + .FirstOrDefault()? + .Name + .ToString(); + + var namespacePrefixes = root.DescendantNodes() + .OfType() + .Where(static directive => directive is { Alias: null, StaticKeyword.RawKind: 0, Name: not null }) + .Select(static directive => directive.Name!.ToString()) + .Append(namespacePrefix) + .Where(static prefix => !string.IsNullOrEmpty(prefix)) + .Distinct(StringComparer.Ordinal) + .OrderByDescending(static prefix => prefix!.Length); + + var prefixes = namespacePrefixes.Select(static prefix => prefix! + ".").ToArray(); + return Regex.Replace( + reduced, + @"(?(?:cref|name)="")(?[^""]*)(?"")", + match => + { + var value = match.Groups["value"].Value; + if (IsMockingReturnsCref(reduced, match.Index)) + { + value = value.Replace("SampleTypeSpec.Models.Custom.", "Models.Custom.", StringComparison.Ordinal); + return match.Groups["attribute"].Value + value + match.Groups["quote"].Value; + } + + if (IsAbstractDerivedTypesCref(trivia, reduced, match.Index)) + { + return match.Groups["attribute"].Value + value + match.Groups["quote"].Value; + } + + foreach (var prefix in prefixes) + { + value = value.Replace(prefix, string.Empty, StringComparison.Ordinal); + } + + return match.Groups["attribute"].Value + value + match.Groups["quote"].Value; + }); + } + + private static bool IsMockingReturnsCref(string text, int index) + { + var lineStart = text.LastIndexOf('\n', Math.Max(0, index - 1)); + var lineEnd = text.IndexOf('\n', index); + lineStart = lineStart < 0 ? 0 : lineStart + 1; + lineEnd = lineEnd < 0 ? text.Length : lineEnd; + return text.Substring(lineStart, lineEnd - lineStart).Contains(" instance for mocking.", StringComparison.Ordinal); + } + + private static bool IsAbstractDerivedTypesCref(SyntaxTrivia trivia, string text, int index) + { + if (trivia.Token.Parent? + .AncestorsAndSelf() + .OfType() + .FirstOrDefault()? + .Identifier.ValueText.EndsWith("ModelFactory", StringComparison.Ordinal) != true) + { + return false; + } + + var lineStart = text.LastIndexOf('\n', Math.Max(0, index - 1)); + var lineEnd = text.IndexOf('\n', index); + lineStart = lineStart < 0 ? 0 : lineStart + 1; + lineEnd = lineEnd < 0 ? text.Length : lineEnd; + return text.Substring(lineStart, lineEnd - lineStart).Contains("derived classes available for instantiation", StringComparison.Ordinal); + } + + private static bool TryGetNameReplacement( + SemanticModel semanticModel, + NameSyntax originalName, + ISymbol originalSymbol, + out NameSyntax replacement) + { + replacement = originalName; + if (!TryGetNameParts(originalName, out var parts)) + { + return false; + } + + for (int i = parts.Count - 1; i >= 0; i--) + { + var candidate = BuildName(parts, i).WithTriviaFrom(originalName); + if (SpeculativelyBindsToSameSymbol(semanticModel, originalName, candidate, originalSymbol)) + { + replacement = candidate; + return true; + } + } + + return false; + } + + private static bool SpeculativelyBindsToSameSymbol( + SemanticModel semanticModel, + NameSyntax originalName, + NameSyntax replacement, + ISymbol originalSymbol) + { + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + originalName.SpanStart, + replacement, + SpeculativeBindingOption.BindAsTypeOrNamespace).Symbol; + if (speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) + { + return true; + } + + if (originalSymbol is ITypeSymbol originalType) + { + var speculativeType = semanticModel.GetSpeculativeTypeInfo( + originalName.SpanStart, + replacement, + SpeculativeBindingOption.BindAsTypeOrNamespace).Type; + if (speculativeType != null && + SymbolEqualityComparer.Default.Equals(originalType, speculativeType)) + { + return true; + } + } + + if (originalName.Parent is MemberAccessExpressionSyntax memberAccess && + memberAccess.Expression == originalName) + { + speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + originalName.SpanStart, + replacement, + SpeculativeBindingOption.BindAsExpression).Symbol; + return speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol); + } + + return false; + } + + private static bool TryGetNameParts(NameSyntax name, out IReadOnlyList parts) + { + var builder = new List(); + if (AddNameParts(name, builder)) + { + parts = builder; + return true; + } + + parts = []; + return false; + } + + private static bool AddNameParts(NameSyntax name, List parts) + { + switch (name) + { + case SimpleNameSyntax simpleName: + parts.Add(simpleName); + return true; + case QualifiedNameSyntax qualifiedName: + if (!AddNameParts(qualifiedName.Left, parts)) + { + return false; + } + + parts.Add(qualifiedName.Right); + return true; + case AliasQualifiedNameSyntax { Alias.Identifier.ValueText: "global" } aliasQualifiedName: + parts.Add(aliasQualifiedName.Name); + return true; + default: + return false; + } + } + + private static NameSyntax BuildName(IReadOnlyList parts, int startIndex) + { + NameSyntax name = parts[startIndex]; + for (int i = startIndex + 1; i < parts.Count; i++) + { + name = SyntaxFactory.QualifiedName(name, parts[i]); + } + + return name; + } + + private static SimpleNameSyntax GetRightmostName(NameSyntax name) => name switch + { + QualifiedNameSyntax qualifiedName => qualifiedName.Right, + AliasQualifiedNameSyntax aliasQualifiedName => aliasQualifiedName.Name, + SimpleNameSyntax simpleName => simpleName, + _ => throw new InvalidOperationException($"Unexpected name syntax: {name.Kind()}") + }; + + private static bool TryGetAttributeNameReplacement( + SemanticModel semanticModel, + AttributeSyntax attribute, + out NameSyntax replacement) + { + replacement = attribute.Name; + if (attribute.Name is not QualifiedNameSyntax and not AliasQualifiedNameSyntax) + { + return false; + } + + var originalSymbol = semanticModel.GetSymbolInfo(attribute).Symbol; + if (originalSymbol is not IMethodSymbol { ContainingType: { } originalAttributeType }) + { + return false; + } + + var rightmostName = GetRightmostName(attribute.Name); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + attribute.Name.SpanStart, + rightmostName, + SpeculativeBindingOption.BindAsTypeOrNamespace).Symbol; + if (!SymbolEqualityComparer.Default.Equals(originalAttributeType, speculativeSymbol)) + { + return false; + } + + replacement = TrimAttributeSuffix(rightmostName).WithTriviaFrom(attribute.Name); + return true; + } + + private static SimpleNameSyntax TrimAttributeSuffix(SimpleNameSyntax name) + { + const string AttributeSuffix = "Attribute"; + var identifier = name.Identifier; + var text = identifier.ValueText; + if (!text.EndsWith(AttributeSuffix, StringComparison.Ordinal) || text.Length == AttributeSuffix.Length) + { + return name; + } + + return SyntaxFactory.IdentifierName( + SyntaxFactory.Identifier( + identifier.LeadingTrivia, + text.Substring(0, text.Length - AttributeSuffix.Length), + identifier.TrailingTrivia)); + } + + private static bool TryGetMemberAccessReplacement( + SemanticModel semanticModel, + MemberAccessExpressionSyntax memberAccess, + out ExpressionSyntax replacement) + { + replacement = memberAccess; + var originalSymbol = semanticModel.GetSymbolInfo(memberAccess).Symbol; + var invocationExpression = memberAccess.Parent is InvocationExpressionSyntax invocation && invocation.Expression == memberAccess + ? invocation + : null; + var originalInvocationSymbol = invocationExpression != null + ? semanticModel.GetSymbolInfo(invocationExpression).Symbol + : null; + if (memberAccess.Expression is PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.IntKeyword or (int)SyntaxKind.FloatKeyword } && + originalInvocationSymbol != null && + TryReduceInvocationExpression(semanticModel, invocationExpression!, memberAccess.Name, originalInvocationSymbol)) + { + replacement = memberAccess.Name.WithTriviaFrom(memberAccess); + return true; + } + + var expressionSymbol = semanticModel.GetSymbolInfo(memberAccess.Expression).Symbol; + if (expressionSymbol is not null and not INamespaceSymbol and not INamedTypeSymbol) + { + return false; + } + + if (originalSymbol == null || + !TryGetMemberAccessParts(memberAccess, out var parts) || + parts.Count < 2) + { + originalSymbol = originalInvocationSymbol; + if (originalSymbol == null || + !TryGetMemberAccessParts(memberAccess, out parts) || + parts.Count < 2) + { + return false; + } + } + + for (int i = parts.Count - 1; i > 0; i--) + { + var candidate = BuildMemberAccess(parts, i).WithTriviaFrom(memberAccess); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + memberAccess.SpanStart, + candidate, + SpeculativeBindingOption.BindAsExpression).Symbol; + if (speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) + { + replacement = candidate; + return true; + } + + if (memberAccess.Parent is InvocationExpressionSyntax parentInvocation && + parentInvocation.Expression == memberAccess && + originalInvocationSymbol != null && + TryReduceInvocationExpression(semanticModel, parentInvocation, candidate, originalInvocationSymbol)) + { + replacement = candidate; + return true; + } + } + + return false; + } + + private static bool TryReduceInvocationExpression( + SemanticModel semanticModel, + InvocationExpressionSyntax invocation, + ExpressionSyntax candidateExpression, + ISymbol originalInvocationSymbol) + { + var candidateInvocation = invocation.WithExpression(candidateExpression); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + invocation.SpanStart, + candidateInvocation, + SpeculativeBindingOption.BindAsExpression).Symbol; + return speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalInvocationSymbol, speculativeSymbol); + } + + private static bool TryGetMemberAccessParts(ExpressionSyntax expression, out IReadOnlyList parts) + { + var builder = new List(); + if (AddMemberAccessParts(expression, builder)) + { + parts = builder; + return true; + } + + parts = []; + return false; + } + + private static bool AddMemberAccessParts(SyntaxNode expression, List parts) + { + switch (expression) + { + case SimpleNameSyntax name: + parts.Add(name); + return true; + case MemberAccessExpressionSyntax memberAccess: + if (!AddMemberAccessParts(memberAccess.Expression, parts)) + { + return false; + } + + parts.Add(memberAccess.Name); + return true; + case QualifiedNameSyntax qualifiedName: + if (!AddMemberAccessParts(qualifiedName.Left, parts)) + { + return false; + } + + parts.Add(qualifiedName.Right); + return true; + case AliasQualifiedNameSyntax { Alias.Identifier.ValueText: "global" } aliasQualifiedName: + parts.Add(aliasQualifiedName.Name); + return true; + default: + return false; + } + } + + private static ExpressionSyntax BuildMemberAccess(IReadOnlyList parts, int startIndex) + { + ExpressionSyntax expression = parts[startIndex]; + for (int i = startIndex + 1; i < parts.Count; i++) + { + expression = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + expression, + parts[i]); + } + + return expression; + } + + private static bool IsInUnsupportedQualifiedNameContext(NameSyntax name) => + name.Ancestors().Any(static ancestor => + ancestor is UsingDirectiveSyntax || + ancestor is CrefSyntax); + + private static bool IsInUnsupportedQualifiedNameContext(ExpressionSyntax expression) => + expression.Ancestors().Any(static ancestor => + ancestor is CrefSyntax); + + private static bool ContainsSimplifierAnnotations(SyntaxNode root) => + root.HasAnnotation(Simplifier.Annotation) || + root.DescendantNodesAndTokens(descendIntoTrivia: true).Any(static nodeOrToken => + nodeOrToken.HasAnnotation(Simplifier.Annotation)); + + private static IReadOnlyList GetSimplifierSpans(SyntaxNode root) + { + List spans = new(); + foreach (var member in root.DescendantNodes().OfType()) + { + if (ContainsReducibleSyntax(member)) + { + spans.Add(member.FullSpan); + } + } + + spans.AddRange(root + .DescendantNodesAndTokens(descendIntoTrivia: true) + .Where(static nodeOrToken => nodeOrToken.HasAnnotation(Simplifier.Annotation)) + .Select(static nodeOrToken => nodeOrToken.FullSpan)); + + return spans; + } + + private static bool ContainsReducibleSyntax(SyntaxNode root) => + root.DescendantNodes( + descendIntoChildren: node => node == root || node is not MemberDeclarationSyntax, + descendIntoTrivia: true).Any(static node => + node is ThisExpressionSyntax || + node is ParenthesizedExpressionSyntax || + node is CrefSyntax || + node is QualifiedNameSyntax || + node is MemberAccessExpressionSyntax || + node is AssignmentExpressionSyntax { RawKind: (int)SyntaxKind.SimpleAssignmentExpression } || + node is AliasQualifiedNameSyntax { Alias.Identifier.ValueText: "global" }); + public static bool IsGeneratedDocument(Document document) => document.Folders.Contains(GeneratedFolder); public static bool IsCustomDocument(Document document) => !IsGeneratedDocument(document); public static bool IsGeneratedTestDocument(Document document) => document.Folders.Contains(GeneratedTestFolder); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs index 0a75d8c9360..d1aec2a8486 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs @@ -4,6 +4,7 @@ using Microsoft.Build.Construction; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.TypeSpec.Generator.Primitives; using Microsoft.TypeSpec.Generator.Tests.Common; using NUnit.Framework; using System; @@ -97,6 +98,514 @@ await MockHelpers.LoadMockGeneratorAsync( Assert.NotNull(fooMethod, "Foo method should be found in the SimpleType"); } + [Test] + public async Task GetGeneratedFilesAsync_SimplifiesFrameworkNamesWhenTypeHasSystemMember() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// + +#nullable disable + +using System; +using System.ComponentModel; + +namespace TestNamespace +{ + public readonly partial struct TestRole : IEquatable + { + private readonly string _value; + private const string SystemValue = "system"; + + public TestRole(string value) + { + _value = value; + } + + public static TestRole System { get; } = new TestRole(SystemValue); + + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Never)] + public override bool Equals(object obj) => obj is TestRole other && Equals(other); + + public bool Equals(TestRole other) => string.Equals(_value, other._value, global::System.StringComparison.InvariantCultureIgnoreCase); + + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Never)] + public override int GetHashCode() => _value != null ? global::System.StringComparer.InvariantCultureIgnoreCase.GetHashCode(_value) : 0; + } +} +""", + typeof(EditorBrowsableAttribute).Assembly.Location); + + Assert.That(generatedText, Is.Not.Null); + Assert.That(generatedText, Does.Contain("[EditorBrowsable(EditorBrowsableState.Never)]")); + Assert.That(generatedText, Does.Contain("StringComparison.InvariantCultureIgnoreCase")); + Assert.That(generatedText, Does.Contain("StringComparer.InvariantCultureIgnoreCase")); + Assert.That(generatedText, Does.Not.Contain("System.ComponentModel.EditorBrowsableAttribute")); + Assert.That(generatedText, Does.Not.Contain("System.StringComparison")); + Assert.That(generatedText, Does.Not.Contain("System.StringComparer")); + } + + [Test] + public async Task GetGeneratedFilesAsync_PreservesQualificationWhenImportedNamespacesContainSameTypeName() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using First; +using Second; + +namespace TestNamespace +{ + public class Container + { + public global::First.Conflict FirstValue { get; } + public global::Second.Conflict SecondValue { get; } + } +} + +namespace First +{ + public class Conflict { } +} + +namespace Second +{ + public class Conflict { } +} +"""); + + Assert.That(generatedText, Does.Contain("First.Conflict FirstValue")); + Assert.That(generatedText, Does.Contain("Second.Conflict SecondValue")); + Assert.That(generatedText, Does.Not.Contain("public Conflict FirstValue")); + Assert.That(generatedText, Does.Not.Contain("public Conflict SecondValue")); + } + + [Test] + public async Task GetGeneratedFilesAsync_PreservesFrameworkQualificationWhenGeneratedModelShadowsFrameworkType() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using System; + +namespace TestNamespace +{ + public class BinaryData { } + + public class Container + { + public global::System.BinaryData Payload { get; } + } +} +"""); + + Assert.That(generatedText, Does.Contain("System.BinaryData Payload")); + Assert.That(generatedText, Does.Not.Contain("public BinaryData Payload")); + } + + [Test] + public async Task GetGeneratedFilesAsync_PreservesQualificationWhenCurrentNamespaceTypeShadowsImportedType() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using External; + +namespace TestNamespace +{ + public class Widget { } + + public class Container + { + public global::External.Widget ExternalWidget { get; } + } +} + +namespace External +{ + public class Widget { } +} +"""); + + Assert.That(generatedText, Does.Contain("External.Widget ExternalWidget")); + Assert.That(generatedText, Does.Not.Contain("public Widget ExternalWidget")); + } + + [Test] + public async Task GetGeneratedFilesAsync_PreservesQualificationWhenParameterNameConflictsWithTypeName() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using System; + +namespace TestNamespace +{ + public class Container + { + public bool Equals(string StringComparison) => global::System.StringComparison.InvariantCultureIgnoreCase.Equals(StringComparison, StringComparison); + } +} +"""); + + Assert.That(generatedText, Does.Contain("System.StringComparison.InvariantCultureIgnoreCase")); + Assert.That(generatedText, Does.Not.Contain("=> StringComparison.InvariantCultureIgnoreCase")); + } + + [Test] + public async Task GetGeneratedFilesAsync_ReducesGlobalAliasesInXmlDocCrefsSeparatelyFromCode() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using System; + +namespace TestNamespace +{ + /// See . + public class Container + { + public global::System.ArgumentNullException Create() => null; + } +} +"""); + + Assert.That(generatedText, Does.Contain("")); + Assert.That(generatedText, Does.Contain("ArgumentNullException Create()")); + Assert.That(generatedText, Does.Not.Contain("global::System.ArgumentNullException")); + } + + [Test] + public async Task GetGeneratedFilesAsync_ReducesQualifiedXmlDocCrefs() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using System; +using System.Text.Json; + +namespace TestNamespace +{ + /// See and . + /// The derived classes available for instantiation are: . + /// A new instance for mocking. + public class Widget + { + public JsonSerializerOptions Options { get; } + } +} +"""); + + Assert.That(generatedText, Does.Contain("")); + Assert.That(generatedText, Does.Contain("derived classes available for instantiation are: ")); + Assert.That(generatedText, Does.Contain("")); + Assert.That(generatedText, Does.Contain(" instance for mocking")); + Assert.That(generatedText, Does.Not.Contain("System.Text.Json.JsonSerializer")); + } + + [Test] + public async Task GetGeneratedFilesAsync_ReducesAliasesGenericNamesAndCustomizationTypesSafely() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using System.Collections.Generic; +using AliasWidget = Customization.Widget; + +namespace TestNamespace +{ + public class Container + { + public global::System.Collections.Generic.IList Widgets { get; } + public global::Customization.Widget Create(AliasWidget widget) => widget; + public string GetFormat() => ((IPersistableModel)this).GetFormatFromOptions(null); + } + + public interface IPersistableModel + { + string GetFormatFromOptions(object options); + } +} + +namespace Customization +{ + public class Widget { } +} +"""); + + Assert.That(generatedText, Does.Contain("IList Widgets")); + Assert.That(generatedText, Does.Contain("Customization.Widget Create(AliasWidget widget)")); + Assert.That(generatedText, Does.Contain("((IPersistableModel)this).GetFormatFromOptions(null)")); + Assert.That(generatedText, Does.Not.Contain("global::System.Collections.Generic.IList")); + Assert.That(generatedText, Does.Not.Contain("global::Customization.Widget")); + } + + [Test] + public async Task GetGeneratedFilesAsync_ReducesQualifiedGenericTypeNames() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using System.ClientModel; +using System.Threading.Tasks; + +namespace TestNamespace +{ + public class Widget { } + + public class Operations + { + public Task> GetAsync() => null; + } +} +""", + typeof(System.ClientModel.ClientResult).Assembly.Location); + + Assert.That(generatedText, Does.Contain("Task> GetAsync()")); + Assert.That(generatedText, Does.Not.Contain("System.ClientModel.ClientResult")); + Assert.That(generatedText, Does.Not.Contain("TestNamespace.Widget")); + } + + [Test] + public async Task GetGeneratedFilesAsync_ReducesSameNamespaceStaticMemberAccess() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +namespace TestNamespace +{ + public class Container + { + public void Invoke(string value) + { + TestNamespace.Argument.AssertNotNull(value, nameof(value)); + } + } + + internal static class Argument + { + public static void AssertNotNull(object value, string name) { } + } +} +"""); + + Assert.That(generatedText, Does.Contain("Argument.AssertNotNull(value, nameof(value));")); + Assert.That(generatedText, Does.Not.Contain("TestNamespace.Argument.AssertNotNull")); + } + + [Test] + public async Task GetGeneratedFilesAsync_PreservesStaticMemberQualificationWhenShortNameConflicts() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +namespace TestNamespace +{ + public class Container + { + public void Invoke(string value) + { + var Argument = new LocalArgument(); + TestNamespace.Argument.AssertNotNull(value, nameof(value)); + } + } + + internal static class Argument + { + public static void AssertNotNull(object value, string name) { } + } + + internal class LocalArgument + { + public void AssertNotNull(object value, string name) { } + } +} +"""); + + Assert.That(generatedText, Does.Contain("TestNamespace.Argument.AssertNotNull(value, nameof(value));")); + } + + [Test] + public async Task GetGeneratedFilesAsync_ReducesGeneratedParentheses() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using System.Collections.Generic; + +namespace TestNamespace +{ + public class Container + { + private Dictionary _map; + private string[] _items; + private int _length; + + public Dictionary Map => (_map ??= new Dictionary()); + + public void Invoke(object writer, WidgetHolder widget, object value, object result, Byte[] bytes, Char[] chars, Options options = (Options)null, string nameHint = (String)null, global::System.String title = (global::System.String)null) + { + if (((value is ICollection collection) && (collection.Count == 0))) + { + Use(((Widget)result)); + } + + if ((_items[(collection.Count - 1)] == null)) + { + return; + } + + switch (collection.Count) + { + case ((>= 200) and (< 300)): + return; + } + + switch (value) + { + case string s when (s.Length > 0): + return; + } + + _map = _map ?? (GetMap() ?? new Dictionary()); + _length = (_length + 1); + string format = (nameHint == "W") ? nameHint : "J"; + string converted = TypeFormatters.ToString(bytes); + writer.WriteObjectValue((Widget)result, options); + widget.Value.Equals(widget.Value); + } + + private void Use(Widget widget) { } + private Dictionary GetMap() => null; + } + + public class Widget { } + public class WidgetHolder + { + public Widget Value { get; } + } + public class Options { } + + internal static class TypeFormatters + { + public static string ToString(byte[] value) => null; + public static string Invoke(byte[] value) => TypeFormatters.ToString(value); + } + + internal static class WriterExtensions + { + public static void WriteObjectValue(this object writer, T value, Options options) { } + } +} +"""); + + Assert.That(generatedText, Does.Contain("Map => _map ??= new Dictionary();")); + Assert.That(generatedText, Does.Contain("if (value is ICollection collection && collection.Count == 0)")); + Assert.That(generatedText, Does.Contain("Use((Widget)result);")); + Assert.That(generatedText, Does.Contain("_items[collection.Count - 1] == null")); + Assert.That(generatedText, Does.Contain("case >= 200 and < 300:")); + Assert.That(generatedText, Does.Contain("byte[] bytes")); + Assert.That(generatedText, Does.Contain("char[] chars")); + Assert.That(generatedText, Does.Contain("Options options = null")); + Assert.That(generatedText, Does.Contain("string nameHint = null")); + Assert.That(generatedText, Does.Contain("string title = null")); + Assert.That(generatedText, Does.Contain("case string s when s.Length > 0:")); + Assert.That(generatedText, Does.Contain("_map = _map ?? GetMap() ?? new Dictionary();")); + Assert.That(generatedText, Does.Contain("_length = _length + 1;")); + Assert.That(generatedText, Does.Contain("string format = nameHint == \"W\" ? nameHint : \"J\";")); + Assert.That(generatedText, Does.Contain("TypeFormatters.ToString(bytes);")); + Assert.That(generatedText, Does.Contain("writer.WriteObjectValue((Widget)result, options);")); + Assert.That(generatedText, Does.Contain("public static string Invoke(byte[] value) => ToString(value);")); + Assert.That(generatedText, Does.Contain("widget.Value.Equals(widget.Value);")); + Assert.That(generatedText, Does.Not.Contain("Use(((Widget)result))")); + Assert.That(generatedText, Does.Not.Contain("if ((value is ICollection collection)")); + Assert.That(generatedText, Does.Not.Contain("Byte[]")); + Assert.That(generatedText, Does.Not.Contain("Char[]")); + Assert.That(generatedText, Does.Not.Contain("(Options)null")); + Assert.That(generatedText, Does.Not.Contain("(String)null")); + } + + [Test] + public async Task GetGeneratedFilesAsync_ReducesThisQualificationSafely() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +namespace TestNamespace +{ + public class Container + { + public string Name { get; } + + public void Invoke() + { + this.Create(this.Name); + } + + private void Create(string name) { } + } +} +"""); + + Assert.That(generatedText, Does.Contain("Create(Name);")); + Assert.That(generatedText, Does.Not.Contain("this.Create")); + Assert.That(generatedText, Does.Not.Contain("this.Name")); + } + + [Test] + public async Task GetGeneratedFilesAsync_PreservesThisQualificationWhenLocalNameConflicts() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +namespace TestNamespace +{ + public class Container + { + public string Name { get; } + + public void Invoke(string Name) + { + this.Create(this.Name); + } + + private void Create(string name) { } + } +} +"""); + + Assert.That(generatedText, Does.Contain("Create(this.Name);")); + } + [Test] public async Task AddPackageReferencesFromProject_AddsReferencesFromCsproj() { @@ -336,6 +845,27 @@ public class Placeholder {{ }} return dllPath; } + private async Task ProcessGeneratedCodeAsync(string content, params string[] additionalMetadataReferencePaths) + { + MockHelpers.LoadMockGenerator( + outputPath: _projectDir, + configuration: "{\"package-name\": \"TestNamespace\"}", + additionalMetadataReferences: additionalMetadataReferencePaths.Select(static path => MetadataReference.CreateFromFile(path))); + + GeneratedCodeWorkspace.Initialize(); + var workspace = await GeneratedCodeWorkspace.Create(false); + await workspace.AddGeneratedFile(new CodeFile(content, "TestFile.cs")); + + string? generatedText = null; + await foreach (var generatedFile in workspace.GetGeneratedFilesAsync()) + { + generatedText = generatedFile.Text; + } + + Assert.That(generatedText, Is.Not.Null); + return generatedText!; + } + private void CreateTestAssemblyAndProjectFile(string nugetCacheDir, string csProjectFileName) { var ns = csProjectFileName.StartsWith("TestNamespaceUnevaluatedFrameworkValue")