diff --git a/src/FastExpressionCompiler.LightExpression/FlatExpression.cs b/src/FastExpressionCompiler.LightExpression/FlatExpression.cs index 11c12e74..b458e288 100644 --- a/src/FastExpressionCompiler.LightExpression/FlatExpression.cs +++ b/src/FastExpressionCompiler.LightExpression/FlatExpression.cs @@ -45,6 +45,23 @@ public enum ExprNodeKind : byte UInt16Pair, } +/// Maps a lambda node to a parameter identity used from an outer scope and therefore captured in closure. +[StructLayout(LayoutKind.Sequential, Pack = 2)] +public readonly struct LambdaClosureParameterUsage +{ + /// The lambda node index containing the parameter usage. + public readonly short LambdaNodeIndex; + + /// The parameter identity id () referenced from outer scope. + public readonly short ParameterId; + + public LambdaClosureParameterUsage(short lambdaNodeIndex, short parameterId) + { + LambdaNodeIndex = lambdaNodeIndex; + ParameterId = parameterId; + } +} + /// Stores one flat expression node plus its intrusive child-link metadata in 24 bytes on 64-bit runtimes. /// /// Layout (64-bit): Type(8) | Obj(8) | _meta(4) | _data(4) = 24 bytes. @@ -197,6 +214,12 @@ public struct ExprTree /// enabling callers to locate all try regions without a full tree traversal. public SmallList, NoArrayPool> TryCatchNodes; + /// Gets or sets mappings of lambda-node index to captured parameter id for nested-lambda closures. + /// Populated by while flattening System.Linq lambdas; + /// each entry means that the lambda body references a parameter that is not declared as that lambda parameter + /// and not declared as a local block/catch variable in that lambda body scope. + public SmallList, NoArrayPool> LambdaClosureParameterUsages; + /// Adds a parameter node and returns its index. public int Parameter(Type type, string name = null) { @@ -900,6 +923,7 @@ private int AddExpression(SysExpr expression) children.Add(AddExpression(lambda.Parameters[i])); var lambdaIndex = _tree.AddRawExpressionNode(expression.Type, null, expression.NodeType, children); _tree.LambdaNodes.Add(lambdaIndex); + CollectLambdaClosureParameterUsages(lambda, lambdaIndex); return lambdaIndex; } case ExpressionType.Block: @@ -1140,6 +1164,91 @@ private int AddExpression(SysExpr expression) } } + private void CollectLambdaClosureParameterUsages(System.Linq.Expressions.LambdaExpression lambda, int lambdaNodeIndex) + { + var collector = new LambdaClosureUsageCollector(lambda); + collector.Visit(lambda.Body); + + var captured = collector.CapturedParameters; + for (var i = 0; i < captured.Count; ++i) + _tree.LambdaClosureParameterUsages.Add(new LambdaClosureParameterUsage( + checked((short)lambdaNodeIndex), + checked((short)GetId(ref _parameterIds, captured[i])))); + } + + private sealed class LambdaClosureUsageCollector : System.Linq.Expressions.ExpressionVisitor + { + private readonly System.Linq.Expressions.LambdaExpression _lambda; + private readonly List _scopedParameters = new(); + private readonly HashSet _scopedParameterSet = new(ReferenceParameterComparer.Instance); + private readonly HashSet _capturedParameterSet = new(ReferenceParameterComparer.Instance); + + public readonly List CapturedParameters = new(); + + public LambdaClosureUsageCollector(System.Linq.Expressions.LambdaExpression lambda) + { + _lambda = lambda; + for (var i = 0; i < lambda.Parameters.Count; ++i) + { + var parameter = lambda.Parameters[i]; + _scopedParameters.Add(parameter); + _scopedParameterSet.Add(parameter); + } + } + + protected override Expression VisitLambda(System.Linq.Expressions.Expression node) => + // Intentionally skip nested lambdas: each lambda closure map is collected independently + // when that lambda node is visited by the parent Builder traversal. + ReferenceEquals(node, _lambda) ? base.VisitLambda(node) : node; + + protected override Expression VisitParameter(SysParameterExpression node) + { + if (!_scopedParameterSet.Contains(node) && _capturedParameterSet.Add(node)) + CapturedParameters.Add(node); + return node; + } + + protected override Expression VisitBlock(System.Linq.Expressions.BlockExpression node) + { + var initialScopeCount = _scopedParameters.Count; + for (var i = 0; i < node.Variables.Count; ++i) + { + var variable = node.Variables[i]; + _scopedParameters.Add(variable); + _scopedParameterSet.Add(variable); + } + var result = base.VisitBlock(node); + for (var i = _scopedParameters.Count - 1; i >= initialScopeCount; --i) + _scopedParameterSet.Remove(_scopedParameters[i]); + _scopedParameters.RemoveRange(initialScopeCount, _scopedParameters.Count - initialScopeCount); + return result; + } + + protected override SysCatchBlock VisitCatchBlock(SysCatchBlock node) + { + var initialScopeCount = _scopedParameters.Count; + if (node.Variable != null) + { + _scopedParameters.Add(node.Variable); + _scopedParameterSet.Add(node.Variable); + } + var result = base.VisitCatchBlock(node); + for (var i = _scopedParameters.Count - 1; i >= initialScopeCount; --i) + _scopedParameterSet.Remove(_scopedParameters[i]); + _scopedParameters.RemoveRange(initialScopeCount, _scopedParameters.Count - initialScopeCount); + return result; + } + + private sealed class ReferenceParameterComparer : IEqualityComparer + { + public static readonly ReferenceParameterComparer Instance = new(); + + public bool Equals(SysParameterExpression x, SysParameterExpression y) => ReferenceEquals(x, y); + + public int GetHashCode(SysParameterExpression obj) => RuntimeHelpers.GetHashCode(obj); + } + } + private int AddConstant(System.Linq.Expressions.ConstantExpression constant) => _tree.Constant(constant.Value, constant.Type); diff --git a/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs b/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs index edf89600..44808035 100644 --- a/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs +++ b/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs @@ -50,7 +50,9 @@ public int Run() Flat_blocks_with_variables_tracked_from_expression_conversion(); Flat_goto_and_label_nodes_tracked_from_expression_conversion(); Flat_try_catch_nodes_tracked_from_expression_conversion(); - return 33; + Flat_lambda_closure_parameter_usages_tracked_for_nested_lambda_from_expression_conversion(); + Flat_lambda_closure_parameter_usages_excludes_nested_lambda_locals(); + return 35; } @@ -931,5 +933,69 @@ public void Flat_try_catch_nodes_tracked_from_expression_conversion() Asserts.AreEqual(1, fe.TryCatchNodes.Count); } + + public void Flat_lambda_closure_parameter_usages_tracked_for_nested_lambda_from_expression_conversion() + { + var p = SysExpr.Parameter(typeof(int), "p"); + var sysLambda = SysExpr.Lambda>>( + SysExpr.Lambda>(p), + p); + + var fe = sysLambda.ToFlatExpression(); + + Asserts.AreEqual(1, fe.LambdaClosureParameterUsages.Count); + + var nestedLambdaIndex = GetSingleNestedLambdaIndex(ref fe); + var pId = GetParameterIdByName(ref fe, "p"); + + var usage = fe.LambdaClosureParameterUsages[0]; + Asserts.AreEqual(nestedLambdaIndex, usage.LambdaNodeIndex); + Asserts.AreEqual(pId, usage.ParameterId); + } + + public void Flat_lambda_closure_parameter_usages_excludes_nested_lambda_locals() + { + var p = SysExpr.Parameter(typeof(int), "p"); + var local = SysExpr.Variable(typeof(int), "local"); + var nested = SysExpr.Lambda>( + SysExpr.Block(new[] { local }, SysExpr.Assign(local, p), local)); + var sysLambda = SysExpr.Lambda>>(nested, p); + + var fe = sysLambda.ToFlatExpression(); + + Asserts.AreEqual(1, fe.LambdaClosureParameterUsages.Count); + + var pId = GetParameterIdByName(ref fe, "p"); + var localId = GetParameterIdByName(ref fe, "local"); + var usage = fe.LambdaClosureParameterUsages[0]; + Asserts.IsFalse(pId == localId); + Asserts.AreEqual(pId, usage.ParameterId); + } + + private static short GetSingleNestedLambdaIndex(ref ExprTree fe) + { + var nestedLambdaIndex = -1; + for (var i = 0; i < fe.LambdaNodes.Count; ++i) + { + var lambdaIndex = fe.LambdaNodes[i]; + if (lambdaIndex == fe.RootIndex) + continue; + if (nestedLambdaIndex != -1) + throw new InvalidOperationException("Expected a single nested lambda."); + nestedLambdaIndex = lambdaIndex; + } + return checked((short)nestedLambdaIndex); + } + + private static short GetParameterIdByName(ref ExprTree fe, string name) + { + for (var i = 0; i < fe.Nodes.Count; ++i) + { + ref var node = ref fe.Nodes[i]; + if (node.NodeType == ExpressionType.Parameter && string.Equals((string)node.Obj, name, StringComparison.Ordinal)) + return checked((short)node.ChildIdx); + } + throw new InvalidOperationException($"Parameter node '{name}' was not found."); + } } }