-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathUseOfUndefinedFunction.cs
More file actions
143 lines (119 loc) · 6.24 KB
/
UseOfUndefinedFunction.cs
File metadata and controls
143 lines (119 loc) · 6.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
// Copyright (c) 2024, 2026 James Draycott <me@racci.dev>. All Rights Reserved.
// Licensed under the AGPL-3.0-or-later License, See LICENSE in the project root
// for license information.
using System.Management.Automation;
using System.Management.Automation.Language;
using System.Management.Automation.Runspaces;
using Compiler.Module.Compiled;
using NLog;
namespace Compiler.Analyser.Rules;
/// <summary>
/// Analyser rule that warns about command invocations which cannot be resolved to a
/// built-in command of a default session, a function or alias available in the same
/// script, or a function exported by one of the imported modules.
/// </summary>
public class UseOfUndefinedFunction : Rule {
    private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

    /// <summary>
    /// A set of all the built-in commands that are provided in a standard session.
    /// This includes modules that are imported by default.
    /// Entries are normalised with <see cref="SanatiseName"/>.
    /// </summary>
    private static readonly HashSet<string> BuiltinsFunctions = [.. GetDefaultSessionFunctions()];

    // Don't use a concurrent dictionary as just having it per thread is actually faster.
    // Per-thread cache: root AST -> sanitised names of the functions and aliases it makes available.
    private static readonly ThreadLocal<Dictionary<Ast, HashSet<string>>> AvailableFunctionsAndAliasesForAst = new(() => []);
    // Per-thread cache: module hash -> sanitised names of the functions a non-local module exports.
    private static readonly ThreadLocal<Dictionary<string, HashSet<string>>> AvailableFunctionsAndAliasesForRemote = new(() => []);

    public override bool SupportsModule<T>(T compiledModule) => compiledModule is CompiledLocalModule;

    /// <summary>
    /// Only command invocations with a statically known name are processed, and only when
    /// no suppression names that command.
    /// </summary>
    /// <remarks>
    /// Suppression data may be a single name or a collection of names. Each name is compared
    /// case-insensitively against the sanitised call name. Any other data shape is logged
    /// and ignored.
    /// </remarks>
    public override bool ShouldProcess(
        Ast node,
        IEnumerable<Suppression> supressions) {
        if (node is not CommandAst commandAst) return false;

        var commandName = commandAst.GetCommandName();
        if (commandName == null) return false;

        var callName = SanatiseName(commandName);
        return !supressions.Any(supression => {
            switch (supression.Data) {
                case IEnumerable<string> functions:
                    return functions.Any(function => SuppressionMatches(function, callName));
                case string function:
                    // Previously a case-sensitive '==' against the lower-cased call name,
                    // which never matched suppressions written with capitals.
                    return SuppressionMatches(function, callName);
                default:
                    Logger.Warn($"Supression data is not a string or IEnumerable<string> for rule {this.GetType().Name}, received {supression?.Data?.GetType()?.Name}");
                    return false;
            }
        });
    }

    /// <summary>
    /// Whether a single suppression entry names the (already sanitised) call name.
    /// The comparison is ordinal and case-insensitive; null entries never match.
    /// </summary>
    private static bool SuppressionMatches(string? function, string callName) =>
        function is not null && function.Equals(callName, StringComparison.OrdinalIgnoreCase);

    /// <summary>
    /// Yields a warning when the command invoked by <paramref name="node"/> cannot be resolved.
    /// </summary>
    /// <param name="node">
    /// The node to analyse; it is cast to <see cref="CommandAst"/>, so it is expected to be a
    /// node accepted by <see cref="ShouldProcess"/>.
    /// </param>
    /// <param name="importedModules">
    /// Modules whose functions are in scope. Local modules contribute the functions/aliases
    /// found in their document AST; other modules contribute their exported functions.
    /// </param>
    /// <returns>Nothing if the command is resolvable, otherwise a single warning.</returns>
    public override IEnumerable<Issue> Analyse(
        Ast node,
        IEnumerable<Compiled> importedModules) {
        var commandAst = (CommandAst)node;
        var commandName = commandAst.GetCommandName();
        var callName = SanatiseName(commandName);

        if (BuiltinsFunctions.Contains(callName)) yield break;

        var rootNode = AstHelper.FindRoot(node);
        if (GetAvailableFunctionsAndAliasesForAst(rootNode).Contains(callName)) yield break;

        foreach (var module in importedModules) {
            var available = module is CompiledLocalModule localModule
                ? GetAvailableFunctionsAndAliasesForAst(localModule.Document.Ast)
                : GetAvailableFunctionsAndAliasesForRemote(module);
            if (available.Contains(callName)) yield break;
        }

        yield return Issue.Warning(
            $"Undefined function '{commandName}'",
            commandAst.CommandElements[0].Extent,
            commandAst
        );
    }

    /// <summary>
    /// Normalises a command name for lookup. The text after the first '.' is dropped
    /// (e.g. an extension), then everything up to the last ':' is dropped (e.g. a scope
    /// prefix), and the result is lower-cased with the invariant culture.
    /// </summary>
    /// <param name="name">The raw command name.</param>
    /// <returns>The normalised name.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="name"/> is null.</exception>
    public static string SanatiseName(string name) {
        ArgumentNullException.ThrowIfNull(name);

        var withOutExtension = name.Contains('.') ? name.Split('.').First() : name;
        var withoutScope = withOutExtension.Contains(':') ? withOutExtension.Split(':').Last() : withOutExtension;
        return withoutScope.ToLowerInvariant();
    }

    /// <summary>
    /// Get all functions which should always be available in a session.
    /// This will collect the following:
    /// <list type="bullet">
    /// - Every command visible in a default <see cref="InitialSessionState"/> session.
    /// - On Windows: commands found with the Windows PowerShell module directory as the
    ///   module path and the system directories as the Path (which includes executables).
    /// </list>
    /// </summary>
    /// <returns>Sanitised command names (see <see cref="SanatiseName"/>); may contain duplicates after sanitising.</returns>
    public static IEnumerable<string> GetDefaultSessionFunctions() {
        var defaultFunctions = new List<string>();

        var sessionState = InitialSessionState.CreateDefault();
        using (var pwsh = PowerShell.Create(sessionState)) {
            defaultFunctions.AddRange(pwsh.Runspace.SessionStateProxy.InvokeCommand
                .GetCommands("*", CommandTypes.All, true)
                .Select(command => command.Name));
        }

        if (OperatingSystem.IsWindows()) {
            defaultFunctions.AddRange(GetWindowsPowerShellCommandNames());
        }

        return defaultFunctions.Distinct().Select(SanatiseName);
    }

    /// <summary>
    /// Lists the command names visible in a session whose module path is restricted to the
    /// Windows PowerShell v1.0 module directory and whose Path only covers the system
    /// directories. ".exe" is stripped from each name. Returns an empty list when the
    /// module directory does not exist.
    /// </summary>
    /// <remarks>
    /// The script runs in-process, so its <c>$env:</c> assignments change this process's
    /// environment. The original values are restored afterwards so that later runspaces and
    /// child processes do not inherit the truncated Path/PSModulePath.
    /// </remarks>
    private static List<string> GetWindowsPowerShellCommandNames() {
        var modulesPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.System), "WindowsPowerShell", "v1.0", "Modules");
        if (!Directory.Exists(modulesPath)) return [];

        var originalModulePath = Environment.GetEnvironmentVariable("PSModulePath");
        var originalPath = Environment.GetEnvironmentVariable("Path");
        try {
            // Double any single quote so the path stays a valid single-quoted PowerShell literal.
            var quotedModulesPath = modulesPath.Replace("'", "''");
            using var ps = PowerShell.Create().AddScript(/*ps1*/ $$"""
                $env:PSModulePath = '{{quotedModulesPath}}';
                $env:Path = "${env:SystemRoot}\system32;${env:SystemRoot};${env:SystemRoot}\System32\Wbem;${env:SystemRoot}\System32\WindowsPowerShell\v1.0\;";
                $PSModuleAutoLoadingPreference = 'All';
                Get-Command * | Select-Object -ExpandProperty Name
                """);

            // Materialise before the PowerShell instance is disposed.
            return ps.Invoke()
                .Select(commandName => ((string)commandName.BaseObject).Replace(".exe", ""))
                .ToList();
        } finally {
            // Setting null removes the variable, which correctly restores an originally-unset value.
            Environment.SetEnvironmentVariable("PSModulePath", originalModulePath);
            Environment.SetEnvironmentVariable("Path", originalPath);
        }
    }

    /// <summary>
    /// Returns the sanitised names of the functions and aliases that
    /// <see cref="AstHelper"/> reports as available from <paramref name="rootNode"/>.
    /// The result is cached per thread, keyed by the AST instance.
    /// </summary>
    private static HashSet<string> GetAvailableFunctionsAndAliasesForAst(Ast rootNode) {
        var cache = AvailableFunctionsAndAliasesForAst.Value!;
        if (!cache.TryGetValue(rootNode, out var set)) {
            set = [];
            set.UnionWith(AstHelper.FindAvailableFunctions(rootNode, false).Select(definition => SanatiseName(definition.Name)));
            set.UnionWith(AstHelper.FindAvailableAliases(rootNode, false).Select(SanatiseName));
            cache[rootNode] = set;
        }

        return set;
    }

    /// <summary>
    /// Returns the sanitised names of the functions exported by a non-local module.
    /// The result is cached per thread, keyed by the module's computed hash. If the hash
    /// cannot be computed, an empty set is returned and nothing is cached.
    /// </summary>
    private static HashSet<string> GetAvailableFunctionsAndAliasesForRemote(Compiled module) {
        if (module.ComputedHash().IsErr(out _, out var hash)) {
            return [];
        }

        var cache = AvailableFunctionsAndAliasesForRemote.Value!;
        if (!cache.TryGetValue(hash, out var set)) {
            set = [];
            set.UnionWith(module.GetExportedFunctions().Select(SanatiseName));
            cache.TryAdd(hash, set);
        }

        return set;
    }
}