diff --git a/src/FastExpressionCompiler.LightExpression/FlatExpression.cs b/src/FastExpressionCompiler.LightExpression/FlatExpression.cs
index 11c12e74..b458e288 100644
--- a/src/FastExpressionCompiler.LightExpression/FlatExpression.cs
+++ b/src/FastExpressionCompiler.LightExpression/FlatExpression.cs
@@ -45,6 +45,23 @@ public enum ExprNodeKind : byte
UInt16Pair,
}
+/// Maps a lambda node to a parameter identity used from an outer scope and therefore captured in closure.
+[StructLayout(LayoutKind.Sequential, Pack = 2)]
+public readonly struct LambdaClosureParameterUsage
+{
+ /// The lambda node index containing the parameter usage.
+ public readonly short LambdaNodeIndex;
+
+ /// The parameter identity id () referenced from outer scope.
+ public readonly short ParameterId;
+
+ public LambdaClosureParameterUsage(short lambdaNodeIndex, short parameterId)
+ {
+ LambdaNodeIndex = lambdaNodeIndex;
+ ParameterId = parameterId;
+ }
+}
+
/// Stores one flat expression node plus its intrusive child-link metadata in 24 bytes on 64-bit runtimes.
///
/// Layout (64-bit): Type(8) | Obj(8) | _meta(4) | _data(4) = 24 bytes.
@@ -197,6 +214,12 @@ public struct ExprTree
/// enabling callers to locate all try regions without a full tree traversal.
public SmallList, NoArrayPool> TryCatchNodes;
+ /// Gets or sets mappings of lambda-node index to captured parameter id for nested-lambda closures.
+ /// Populated by while flattening System.Linq lambdas;
+ /// each entry means that the lambda body references a parameter that is not declared as that lambda parameter
+ /// and not declared as a local block/catch variable in that lambda body scope.
+ public SmallList, NoArrayPool> LambdaClosureParameterUsages;
+
/// Adds a parameter node and returns its index.
public int Parameter(Type type, string name = null)
{
@@ -900,6 +923,7 @@ private int AddExpression(SysExpr expression)
children.Add(AddExpression(lambda.Parameters[i]));
var lambdaIndex = _tree.AddRawExpressionNode(expression.Type, null, expression.NodeType, children);
_tree.LambdaNodes.Add(lambdaIndex);
+ CollectLambdaClosureParameterUsages(lambda, lambdaIndex);
return lambdaIndex;
}
case ExpressionType.Block:
@@ -1140,6 +1164,91 @@ private int AddExpression(SysExpr expression)
}
}
+ private void CollectLambdaClosureParameterUsages(System.Linq.Expressions.LambdaExpression lambda, int lambdaNodeIndex)
+ {
+ var collector = new LambdaClosureUsageCollector(lambda);
+ collector.Visit(lambda.Body);
+
+ var captured = collector.CapturedParameters;
+ for (var i = 0; i < captured.Count; ++i)
+ _tree.LambdaClosureParameterUsages.Add(new LambdaClosureParameterUsage(
+ checked((short)lambdaNodeIndex),
+ checked((short)GetId(ref _parameterIds, captured[i]))));
+ }
+
+ private sealed class LambdaClosureUsageCollector : System.Linq.Expressions.ExpressionVisitor
+ {
+ private readonly System.Linq.Expressions.LambdaExpression _lambda;
+ private readonly List _scopedParameters = new();
+ private readonly HashSet _scopedParameterSet = new(ReferenceParameterComparer.Instance);
+ private readonly HashSet _capturedParameterSet = new(ReferenceParameterComparer.Instance);
+
+ public readonly List CapturedParameters = new();
+
+ public LambdaClosureUsageCollector(System.Linq.Expressions.LambdaExpression lambda)
+ {
+ _lambda = lambda;
+ for (var i = 0; i < lambda.Parameters.Count; ++i)
+ {
+ var parameter = lambda.Parameters[i];
+ _scopedParameters.Add(parameter);
+ _scopedParameterSet.Add(parameter);
+ }
+ }
+
+ protected override Expression VisitLambda(System.Linq.Expressions.Expression node) =>
+ // Intentionally skip nested lambdas: each lambda closure map is collected independently
+ // when that lambda node is visited by the parent Builder traversal.
+ ReferenceEquals(node, _lambda) ? base.VisitLambda(node) : node;
+
+ protected override Expression VisitParameter(SysParameterExpression node)
+ {
+ if (!_scopedParameterSet.Contains(node) && _capturedParameterSet.Add(node))
+ CapturedParameters.Add(node);
+ return node;
+ }
+
+ protected override Expression VisitBlock(System.Linq.Expressions.BlockExpression node)
+ {
+ var initialScopeCount = _scopedParameters.Count;
+ for (var i = 0; i < node.Variables.Count; ++i)
+ {
+ var variable = node.Variables[i];
+ _scopedParameters.Add(variable);
+ _scopedParameterSet.Add(variable);
+ }
+ var result = base.VisitBlock(node);
+ for (var i = _scopedParameters.Count - 1; i >= initialScopeCount; --i)
+ _scopedParameterSet.Remove(_scopedParameters[i]);
+ _scopedParameters.RemoveRange(initialScopeCount, _scopedParameters.Count - initialScopeCount);
+ return result;
+ }
+
+ protected override SysCatchBlock VisitCatchBlock(SysCatchBlock node)
+ {
+ var initialScopeCount = _scopedParameters.Count;
+ if (node.Variable != null)
+ {
+ _scopedParameters.Add(node.Variable);
+ _scopedParameterSet.Add(node.Variable);
+ }
+ var result = base.VisitCatchBlock(node);
+ for (var i = _scopedParameters.Count - 1; i >= initialScopeCount; --i)
+ _scopedParameterSet.Remove(_scopedParameters[i]);
+ _scopedParameters.RemoveRange(initialScopeCount, _scopedParameters.Count - initialScopeCount);
+ return result;
+ }
+
+ private sealed class ReferenceParameterComparer : IEqualityComparer
+ {
+ public static readonly ReferenceParameterComparer Instance = new();
+
+ public bool Equals(SysParameterExpression x, SysParameterExpression y) => ReferenceEquals(x, y);
+
+ public int GetHashCode(SysParameterExpression obj) => RuntimeHelpers.GetHashCode(obj);
+ }
+ }
+
private int AddConstant(System.Linq.Expressions.ConstantExpression constant) =>
_tree.Constant(constant.Value, constant.Type);
diff --git a/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs b/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs
index edf89600..44808035 100644
--- a/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs
+++ b/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs
@@ -50,7 +50,9 @@ public int Run()
Flat_blocks_with_variables_tracked_from_expression_conversion();
Flat_goto_and_label_nodes_tracked_from_expression_conversion();
Flat_try_catch_nodes_tracked_from_expression_conversion();
- return 33;
+ Flat_lambda_closure_parameter_usages_tracked_for_nested_lambda_from_expression_conversion();
+ Flat_lambda_closure_parameter_usages_excludes_nested_lambda_locals();
+ return 35;
}
@@ -931,5 +933,69 @@ public void Flat_try_catch_nodes_tracked_from_expression_conversion()
Asserts.AreEqual(1, fe.TryCatchNodes.Count);
}
+
+ public void Flat_lambda_closure_parameter_usages_tracked_for_nested_lambda_from_expression_conversion()
+ {
+ var p = SysExpr.Parameter(typeof(int), "p");
+ var sysLambda = SysExpr.Lambda>>(
+ SysExpr.Lambda>(p),
+ p);
+
+ var fe = sysLambda.ToFlatExpression();
+
+ Asserts.AreEqual(1, fe.LambdaClosureParameterUsages.Count);
+
+ var nestedLambdaIndex = GetSingleNestedLambdaIndex(ref fe);
+ var pId = GetParameterIdByName(ref fe, "p");
+
+ var usage = fe.LambdaClosureParameterUsages[0];
+ Asserts.AreEqual(nestedLambdaIndex, usage.LambdaNodeIndex);
+ Asserts.AreEqual(pId, usage.ParameterId);
+ }
+
+ public void Flat_lambda_closure_parameter_usages_excludes_nested_lambda_locals()
+ {
+ var p = SysExpr.Parameter(typeof(int), "p");
+ var local = SysExpr.Variable(typeof(int), "local");
+ var nested = SysExpr.Lambda>(
+ SysExpr.Block(new[] { local }, SysExpr.Assign(local, p), local));
+ var sysLambda = SysExpr.Lambda>>(nested, p);
+
+ var fe = sysLambda.ToFlatExpression();
+
+ Asserts.AreEqual(1, fe.LambdaClosureParameterUsages.Count);
+
+ var pId = GetParameterIdByName(ref fe, "p");
+ var localId = GetParameterIdByName(ref fe, "local");
+ var usage = fe.LambdaClosureParameterUsages[0];
+ Asserts.IsFalse(pId == localId);
+ Asserts.AreEqual(pId, usage.ParameterId);
+ }
+
+ private static short GetSingleNestedLambdaIndex(ref ExprTree fe)
+ {
+ var nestedLambdaIndex = -1;
+ for (var i = 0; i < fe.LambdaNodes.Count; ++i)
+ {
+ var lambdaIndex = fe.LambdaNodes[i];
+ if (lambdaIndex == fe.RootIndex)
+ continue;
+ if (nestedLambdaIndex != -1)
+ throw new InvalidOperationException("Expected a single nested lambda.");
+ nestedLambdaIndex = lambdaIndex;
+ }
+ return checked((short)nestedLambdaIndex);
+ }
+
+ private static short GetParameterIdByName(ref ExprTree fe, string name)
+ {
+ for (var i = 0; i < fe.Nodes.Count; ++i)
+ {
+ ref var node = ref fe.Nodes[i];
+ if (node.NodeType == ExpressionType.Parameter && string.Equals((string)node.Obj, name, StringComparison.Ordinal))
+ return checked((short)node.ChildIdx);
+ }
+ throw new InvalidOperationException($"Parameter node '{name}' was not found.");
+ }
}
}