diff --git a/src/FastExpressionCompiler.LightExpression/FlatExpression.cs b/src/FastExpressionCompiler.LightExpression/FlatExpression.cs index 52654ef8..16fb28f8 100644 --- a/src/FastExpressionCompiler.LightExpression/FlatExpression.cs +++ b/src/FastExpressionCompiler.LightExpression/FlatExpression.cs @@ -23,6 +23,12 @@ public enum ExprNodeKind : byte { /// Represents a regular expression node. Expression, + /// Represents a parameter declaration node. + ParameterDeclaration, + /// Represents a parameter usage node that points to its declaration node index. + ParameterUsage, + /// Represents a generic reference to an already-linked node. + NodeReference, /// Represents a switch case payload. SwitchCase, /// Represents a catch block payload. @@ -49,17 +55,9 @@ public enum ExprNodeKind : byte [StructLayout(LayoutKind.Explicit, Size = 24)] public struct ExprNode { - private const int NodeTypeShift = 56; - private const int TagShift = 48; - private const int NextShift = 32; - private const int CountShift = 16; - private const ulong IndexMask = 0xFFFF; - private const ulong KindMask = 0x0F; - private const ulong NextMask = IndexMask << NextShift; - private const ulong ChildCountMask = IndexMask << CountShift; - private const ulong ChildInfoMask = ChildCountMask | IndexMask; - private const ulong KeepWithoutNextMask = ~NextMask; - private const ulong KeepWithoutChildInfoMask = ~ChildInfoMask; + private const byte KindMask = 0x0F; + private const byte NextReservedFlag = 0x4; + private const byte NextPointsParentFlag = 0x8; private const int FlagsShift = 4; /// Gets or sets the runtime type of the represented node. @@ -70,66 +68,97 @@ public struct ExprNode [FieldOffset(8)] public object Obj; [FieldOffset(16)] - private ulong _data; + private ushort _childIdx; + [FieldOffset(18)] + private ushort _childCount; + [FieldOffset(20)] + private ushort _nextIdx; + [FieldOffset(22)] + private byte _tag; + [FieldOffset(23)] + private byte _nodeType; /// Gets the expression kind encoded for this node. - public ExpressionType NodeType => (ExpressionType)((_data >> NodeTypeShift) & 0xFF); + public ExpressionType NodeType => (ExpressionType)_nodeType; /// Gets the payload classification for this node. - public ExprNodeKind Kind => (ExprNodeKind)((_data >> TagShift) & KindMask); + public ExprNodeKind Kind => (ExprNodeKind)(_tag & KindMask); - internal byte Flags => (byte)(((byte)(_data >> TagShift)) >> FlagsShift); + internal byte Flags => (byte)(_tag >> FlagsShift); /// Gets the next sibling node index in the intrusive child chain. - public int NextIdx => (int)((_data >> NextShift) & IndexMask); + public int NextIdx => _nextIdx; + + internal bool IsParentLink => (Flags & NextPointsParentFlag) != 0; + + internal bool HasNextLink => _nextIdx != 0 || (Flags & (NextPointsParentFlag | NextReservedFlag)) != 0; /// Gets the number of direct children linked from this node. - public int ChildCount => (int)((_data >> CountShift) & IndexMask); + public int ChildCount => _childCount; /// Gets the first child index or an auxiliary payload index. - public int ChildIdx => (int)(_data & IndexMask); + public int ChildIdx => _childIdx; internal ExprNode(Type type, object obj, ExpressionType nodeType, ExprNodeKind kind, byte flags = 0, int childIdx = 0, int childCount = 0, int nextIdx = 0) { Type = type; Obj = obj; - var tag = (byte)((flags << FlagsShift) | (byte)kind); - _data = ((ulong)(byte)nodeType << NodeTypeShift) - | ((ulong)tag << TagShift) - | ((ulong)(ushort)nextIdx << NextShift) - | ((ulong)(ushort)childCount << CountShift) - | (ushort)childIdx; + _childIdx = (ushort)childIdx; + _childCount = (ushort)childCount; + _nextIdx = (ushort)nextIdx; + _tag = (byte)((flags << FlagsShift) | (byte)kind); + _nodeType = (byte)nodeType; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void SetNextIdx(int nextIdx) => - _data = (_data & KeepWithoutNextMask) | ((ulong)(ushort)nextIdx << NextShift); + internal void SetNextSiblingIdx(int nextIdx) + { + _nextIdx = (ushort)nextIdx; + SetFlags((byte)(Flags & ~(NextPointsParentFlag | NextReservedFlag))); + } [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void SetChildInfo(int childIdx, int childCount) => - _data = (_data & KeepWithoutChildInfoMask) - | ((ulong)(ushort)childCount << CountShift) - | (ushort)childIdx; + internal void SetParentIdx(int parentIdx) + { + _nextIdx = (ushort)parentIdx; + SetFlags((byte)((Flags | NextPointsParentFlag) & ~NextReservedFlag)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void SetChildInfo(int childIdx, int childCount) + { + _childIdx = (ushort)childIdx; + _childCount = (ushort)childCount; + } [MethodImpl(MethodImplOptions.AggressiveInlining)] internal bool Is(ExprNodeKind kind) => Kind == kind; [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal bool IsExpression() => Kind == ExprNodeKind.Expression; + internal bool IsExpression() => + Kind == ExprNodeKind.Expression || Kind == ExprNodeKind.ParameterDeclaration || Kind == ExprNodeKind.ParameterUsage; [MethodImpl(MethodImplOptions.AggressiveInlining)] internal bool HasFlag(byte flag) => (Flags & flag) != 0; + internal byte CopyableFlags => (byte)(Flags & ~(NextPointsParentFlag | NextReservedFlag)); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ReserveAsLinked() => SetFlags((byte)(Flags | NextReservedFlag)); + [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal bool ShouldCloneWhenLinked() => - Kind == ExprNodeKind.LabelTarget || NodeType == ExpressionType.Parameter || Kind == ExprNodeKind.ObjectReference || ChildCount == 0; + private void SetFlags(byte flags) + { + _tag = (byte)((flags << FlagsShift) | (byte)Kind); + } } /// Stores an expression tree as a flat node array plus out-of-line closure constants. public struct ExprTree { - private static readonly object ClosureConstantMarker = new(); private const byte ParameterByRefFlag = 1; + private const byte ConstantInClosureFlag = 2; + private const int UnboundParameterPosition = ushort.MaxValue; private const byte BinaryLiftedToNullFlag = 1; private const byte LoopHasBreakFlag = 1; private const byte LoopHasContinueFlag = 2; @@ -143,14 +172,15 @@ public struct ExprTree /// Gets or sets the flat node storage. public SmallList, NoArrayPool> Nodes; - /// Gets or sets closure constants that are referenced from constant nodes. + /// Gets or sets non-inlined constants referenced from constant nodes via . public SmallList, NoArrayPool> ClosureConstants; /// Adds a parameter node and returns its index. public int Parameter(Type type, string name = null) { - var id = Nodes.Count + 1; - return AddRawLeafExpressionNode(type, name, ExpressionType.Parameter, type.IsByRef ? ParameterByRefFlag : (byte)0, childIdx: id); + var parameterType = type ?? throw new ArgumentNullException(nameof(type)); + return AddLeafNode(parameterType, name, ExpressionType.Parameter, ExprNodeKind.ParameterDeclaration, + parameterType.IsByRef ? ParameterByRefFlag : (byte)0, childIdx: 0, childCount: UnboundParameterPosition); } /// Adds a typed parameter node and returns its index. @@ -173,7 +203,7 @@ public int Constant(object value, Type type) return AddRawExpressionNode(type, value, ExpressionType.Constant); var constantIndex = ClosureConstants.Add(value); - return AddRawExpressionNodeWithChildIndex(type, ClosureConstantMarker, ExpressionType.Constant, constantIndex); + return AddRawExpressionNodeWithChildIndex(type, null, ExpressionType.Constant, ConstantInClosureFlag, constantIndex); } /// Adds a null constant node. @@ -310,19 +340,27 @@ public int Block(Type type, IEnumerable variables, params int[] expressions throw new ArgumentException("Block should contain at least one expression.", nameof(expressions)); ChildList children = default; + ChildList variableDeclarations = default; if (variables != null) { ChildList variableChildren = default; foreach (var variable in variables) - variableChildren.Add(variable); + { + var declaration = NormalizeParameterDeclarationIndex(variable); + variableChildren.Add(declaration); + variableDeclarations.Add(declaration); + } if (variableChildren.Count != 0) children.Add(AddChildListNode(in variableChildren)); } ChildList bodyChildren = default; for (var i = 0; i < expressions.Length; ++i) - bodyChildren.Add(expressions[i]); + bodyChildren.Add(CloneChild(expressions[i])); children.Add(AddChildListNode(in bodyChildren)); - return AddFactoryExpressionNode(type ?? Nodes[expressions[expressions.Length - 1]].Type, null, ExpressionType.Block, in children); + var blockIndex = AddFactoryExpressionNode(type ?? Nodes[expressions[expressions.Length - 1]].Type, null, ExpressionType.Block, in children); + for (var i = 0; i < variableDeclarations.Count; ++i) + BindParameterDeclaration(variableDeclarations[i], blockIndex, i); + return blockIndex; } /// Adds a typed lambda node. @@ -330,10 +368,20 @@ public int Lambda(int body, params int[] parameters) where TDelegate Lambda(typeof(TDelegate), body, parameters); /// Adds a lambda node. - public int Lambda(Type delegateType, int body, params int[] parameters) => - parameters == null || parameters.Length == 0 - ? AddFactoryExpressionNode(delegateType, null, ExpressionType.Lambda, 0, body) - : AddFactoryExpressionNode(delegateType, null, ExpressionType.Lambda, PrependToChildList(body, parameters)); + public int Lambda(Type delegateType, int body, params int[] parameters) + { + if (parameters == null || parameters.Length == 0) + return AddFactoryExpressionNode(delegateType, null, ExpressionType.Lambda, 0, body); + + var declarations = new int[parameters.Length]; + for (var i = 0; i < parameters.Length; ++i) + declarations[i] = NormalizeParameterDeclarationIndex(parameters[i]); + + var lambdaIndex = AddFactoryExpressionNode(delegateType, null, ExpressionType.Lambda, PrependToChildList(body, declarations)); + for (var i = 0; i < declarations.Length; ++i) + BindParameterDeclaration(declarations[i], lambdaIndex, i); + return lambdaIndex; + } /// Adds a member-assignment binding node. public int Bind(System.Reflection.MemberInfo member, int expression) => @@ -545,6 +593,40 @@ public SysExpr ToExpression() => [RequiresUnreferencedCode(FastExpressionCompiler.LightExpression.Trimming.Message)] public LightExpression ToLightExpression() => FastExpressionCompiler.LightExpression.FromSysExpressionConverter.ToLightExpression(ToExpression()); + private int NormalizeParameterDeclarationIndex(int parameterIndex) + { + ref var node = ref Nodes[parameterIndex]; + if (node.Is(ExprNodeKind.ParameterUsage)) + return node.ChildIdx; + if (node.Is(ExprNodeKind.ParameterDeclaration)) + return parameterIndex; + throw new InvalidOperationException($"Node at index {parameterIndex} is not a parameter declaration or usage."); + } + + private void BindParameterDeclaration(int declarationIndex, int ownerIndex, int position) + { + ref var declaration = ref Nodes[declarationIndex]; + if (!declaration.Is(ExprNodeKind.ParameterDeclaration)) + throw new InvalidOperationException($"Node at index {declarationIndex} is not a parameter declaration."); + if (declaration.ChildCount != UnboundParameterPosition) + throw new InvalidOperationException($"Parameter declaration at index {declarationIndex} is already bound to an owner scope."); + declaration.SetChildInfo(ownerIndex, position); + } + + private int AddParameterUsageNode(int declarationIndex) + { + ref var declaration = ref Nodes[declarationIndex]; + Debug.Assert(declaration.Is(ExprNodeKind.ParameterDeclaration)); + return AddLeafNode(declaration.Type, declaration.Obj, ExpressionType.Parameter, ExprNodeKind.ParameterUsage, + declaration.CopyableFlags, declarationIndex, 0); + } + + private int AddNodeReference(int index) + { + ref var node = ref Nodes[index]; + return AddLeafNode(node.Type, node.Obj, node.NodeType, ExprNodeKind.NodeReference, node.CopyableFlags, index, 0); + } + private int AddFactoryExpressionNode(Type type, object obj, ExpressionType nodeType, int child) => AddNode(type, obj, nodeType, ExprNodeKind.Expression, 0, CloneChild(child)); @@ -614,8 +696,8 @@ private int AddRawExpressionNode(Type type, object obj, ExpressionType nodeType, private int AddRawLeafExpressionNode(Type type, object obj, ExpressionType nodeType, byte flags = 0, int childIdx = 0, int childCount = 0) => AddLeafNode(type, obj, nodeType, ExprNodeKind.Expression, flags, childIdx, childCount); - private int AddRawExpressionNodeWithChildIndex(Type type, object obj, ExpressionType nodeType, int childIdx) => - AddRawLeafExpressionNode(type, obj, nodeType, childIdx: childIdx); + private int AddRawExpressionNodeWithChildIndex(Type type, object obj, ExpressionType nodeType, byte flags, int childIdx) => + AddRawLeafExpressionNode(type, obj, nodeType, flags, childIdx, 0); private int AddFactoryAuxNode(Type type, object obj, ExprNodeKind kind, byte flags, int child) => AddNode(type, obj, ExpressionType.Extension, kind, flags, CloneChild(child)); @@ -698,34 +780,49 @@ private int AddExpression(SysExpr expression) case ExpressionType.Parameter: { var parameter = (SysParameterExpression)expression; - return _tree.AddRawLeafExpressionNode(expression.Type, parameter.Name, expression.NodeType, - parameter.IsByRef ? ParameterByRefFlag : (byte)0, childIdx: GetId(ref _parameterIds, parameter)); + return _tree.AddParameterUsageNode(GetParameterDeclarationIndex(parameter)); } case ExpressionType.Lambda: { var lambda = (System.Linq.Expressions.LambdaExpression)expression; ChildList children = default; children.Add(AddExpression(lambda.Body)); + ChildList declarations = default; for (var i = 0; i < lambda.Parameters.Count; ++i) - children.Add(AddExpression(lambda.Parameters[i])); - return _tree.AddRawExpressionNode(expression.Type, null, expression.NodeType, children); + { + var declaration = GetParameterDeclarationIndex(lambda.Parameters[i]); + declarations.Add(declaration); + children.Add(declaration); + } + var lambdaIndex = _tree.AddRawExpressionNode(expression.Type, null, expression.NodeType, children); + for (var i = 0; i < declarations.Count; ++i) + _tree.BindParameterDeclaration(declarations[i], lambdaIndex, i); + return lambdaIndex; } case ExpressionType.Block: { var block = (System.Linq.Expressions.BlockExpression)expression; ChildList children = default; + ChildList declarations = default; if (block.Variables.Count != 0) { ChildList variables = default; for (var i = 0; i < block.Variables.Count; ++i) - variables.Add(AddExpression(block.Variables[i])); + { + var declaration = GetParameterDeclarationIndex(block.Variables[i]); + declarations.Add(declaration); + variables.Add(declaration); + } children.Add(_tree.AddChildListNode(in variables)); } ChildList expressions = default; for (var i = 0; i < block.Expressions.Count; ++i) expressions.Add(AddExpression(block.Expressions[i])); children.Add(_tree.AddChildListNode(in expressions)); - return _tree.AddRawExpressionNode(expression.Type, null, expression.NodeType, in children); + var blockIndex = _tree.AddRawExpressionNode(expression.Type, null, expression.NodeType, in children); + for (var i = 0; i < declarations.Count; ++i) + _tree.BindParameterDeclaration(declarations[i], blockIndex, i); + return blockIndex; } case ExpressionType.MemberAccess: { @@ -940,7 +1037,7 @@ private int AddConstant(System.Linq.Expressions.ConstantExpression constant) return _tree.AddRawExpressionNode(constant.Type, constant.Value, constant.NodeType); var constantIndex = _tree.ClosureConstants.Add(constant.Value); - return _tree.AddRawExpressionNodeWithChildIndex(constant.Type, ClosureConstantMarker, constant.NodeType, constantIndex); + return _tree.AddRawExpressionNodeWithChildIndex(constant.Type, null, constant.NodeType, ConstantInClosureFlag, constantIndex); } private int AddSwitchCase(SysSwitchCase switchCase) @@ -1005,6 +1102,14 @@ private int AddElementInit(SysElementInit init) return _tree.AddRawAuxNode(init.AddMethod.DeclaringType, init.AddMethod, ExprNodeKind.ElementInit, children); } + private int GetParameterDeclarationIndex(SysParameterExpression parameter) + { + ref var declaration = ref _parameterIds.Map.AddOrGetValueRef(parameter, out var found); + if (!found) + declaration = _tree.Parameter(parameter.Type, parameter.Name); + return declaration; + } + private static int GetId(ref SmallMap16> ids, object item) { ref var id = ref ids.Map.AddOrGetValueRef(item, out var found); @@ -1042,6 +1147,7 @@ private int AddNode(Type type, object obj, ExpressionType nodeType, ExprNodeKind var nodeIndex = Nodes.Count; ref var newNode = ref Nodes.AddDefaultAndGetRef(); newNode = new ExprNode(type, obj, nodeType, kind, flags, child0, 1); + Nodes[child0].SetParentIdx(nodeIndex); return nodeIndex; } @@ -1050,7 +1156,8 @@ private int AddNode(Type type, object obj, ExpressionType nodeType, ExprNodeKind var nodeIndex = Nodes.Count; ref var newNode = ref Nodes.AddDefaultAndGetRef(); newNode = new ExprNode(type, obj, nodeType, kind, flags, c0, 2); - Nodes[c0].SetNextIdx(c1); + Nodes[c0].SetNextSiblingIdx(c1); + Nodes[c1].SetParentIdx(nodeIndex); return nodeIndex; } @@ -1059,8 +1166,9 @@ private int AddNode(Type type, object obj, ExpressionType nodeType, ExprNodeKind var nodeIndex = Nodes.Count; ref var newNode = ref Nodes.AddDefaultAndGetRef(); newNode = new ExprNode(type, obj, nodeType, kind, flags, c0, 3); - Nodes[c0].SetNextIdx(c1); - Nodes[c1].SetNextIdx(c2); + Nodes[c0].SetNextSiblingIdx(c1); + Nodes[c1].SetNextSiblingIdx(c2); + Nodes[c2].SetParentIdx(nodeIndex); return nodeIndex; } @@ -1069,9 +1177,10 @@ private int AddNode(Type type, object obj, ExpressionType nodeType, ExprNodeKind var nodeIndex = Nodes.Count; ref var newNode = ref Nodes.AddDefaultAndGetRef(); newNode = new ExprNode(type, obj, nodeType, kind, flags, c0, 4); - Nodes[c0].SetNextIdx(c1); - Nodes[c1].SetNextIdx(c2); - Nodes[c2].SetNextIdx(c3); + Nodes[c0].SetNextSiblingIdx(c1); + Nodes[c1].SetNextSiblingIdx(c2); + Nodes[c2].SetNextSiblingIdx(c3); + Nodes[c3].SetParentIdx(nodeIndex); return nodeIndex; } @@ -1080,10 +1189,11 @@ private int AddNode(Type type, object obj, ExpressionType nodeType, ExprNodeKind var nodeIndex = Nodes.Count; ref var newNode = ref Nodes.AddDefaultAndGetRef(); newNode = new ExprNode(type, obj, nodeType, kind, flags, c0, 5); - Nodes[c0].SetNextIdx(c1); - Nodes[c1].SetNextIdx(c2); - Nodes[c2].SetNextIdx(c3); - Nodes[c3].SetNextIdx(c4); + Nodes[c0].SetNextSiblingIdx(c1); + Nodes[c1].SetNextSiblingIdx(c2); + Nodes[c2].SetNextSiblingIdx(c3); + Nodes[c3].SetNextSiblingIdx(c4); + Nodes[c4].SetParentIdx(nodeIndex); return nodeIndex; } @@ -1092,11 +1202,12 @@ private int AddNode(Type type, object obj, ExpressionType nodeType, ExprNodeKind var nodeIndex = Nodes.Count; ref var newNode = ref Nodes.AddDefaultAndGetRef(); newNode = new ExprNode(type, obj, nodeType, kind, flags, c0, 6); - Nodes[c0].SetNextIdx(c1); - Nodes[c1].SetNextIdx(c2); - Nodes[c2].SetNextIdx(c3); - Nodes[c3].SetNextIdx(c4); - Nodes[c4].SetNextIdx(c5); + Nodes[c0].SetNextSiblingIdx(c1); + Nodes[c1].SetNextSiblingIdx(c2); + Nodes[c2].SetNextSiblingIdx(c3); + Nodes[c3].SetNextSiblingIdx(c4); + Nodes[c4].SetNextSiblingIdx(c5); + Nodes[c5].SetParentIdx(nodeIndex); return nodeIndex; } @@ -1105,12 +1216,13 @@ private int AddNode(Type type, object obj, ExpressionType nodeType, ExprNodeKind var nodeIndex = Nodes.Count; ref var newNode = ref Nodes.AddDefaultAndGetRef(); newNode = new ExprNode(type, obj, nodeType, kind, flags, c0, 7); - Nodes[c0].SetNextIdx(c1); - Nodes[c1].SetNextIdx(c2); - Nodes[c2].SetNextIdx(c3); - Nodes[c3].SetNextIdx(c4); - Nodes[c4].SetNextIdx(c5); - Nodes[c5].SetNextIdx(c6); + Nodes[c0].SetNextSiblingIdx(c1); + Nodes[c1].SetNextSiblingIdx(c2); + Nodes[c2].SetNextSiblingIdx(c3); + Nodes[c3].SetNextSiblingIdx(c4); + Nodes[c4].SetNextSiblingIdx(c5); + Nodes[c5].SetNextSiblingIdx(c6); + Nodes[c6].SetParentIdx(nodeIndex); return nodeIndex; } @@ -1123,7 +1235,8 @@ private int AddNode(Type type, object obj, ExpressionType nodeType, ExprNodeKind ref var newNode = ref Nodes.AddDefaultAndGetRef(); newNode = new ExprNode(type, obj, nodeType, kind, flags, children[0], children.Length); for (var i = 1; i < children.Length; ++i) - Nodes[children[i - 1]].SetNextIdx(children[i]); + Nodes[children[i - 1]].SetNextSiblingIdx(children[i]); + Nodes[children[children.Length - 1]].SetParentIdx(nodeIndex); return nodeIndex; } @@ -1136,12 +1249,14 @@ private int AddNode(Type type, object obj, ExpressionType nodeType, ExprNodeKind ref var newNode = ref Nodes.AddDefaultAndGetRef(); newNode = new ExprNode(type, obj, nodeType, kind, flags, children[0], children.Count); for (var i = 1; i < children.Count; ++i) - Nodes[children[i - 1]].SetNextIdx(children[i]); + Nodes[children[i - 1]].SetNextSiblingIdx(children[i]); + Nodes[children[children.Count - 1]].SetParentIdx(nodeIndex); return nodeIndex; } [MethodImpl(MethodImplOptions.AggressiveInlining)] private static bool ShouldInlineConstant(object value, Type type) => + // Inlined constants are stored directly in ExprNode.Obj (boxed for value types). value == null || value is string || value is Type || type.IsEnum || Type.GetTypeCode(type) != TypeCode.Object; private static Type GetMemberType(System.Reflection.MemberInfo member) => member switch @@ -1184,9 +1299,22 @@ private static Type GetArrayElementType(Type arrayType, int depth) private int CloneChild(int index) { ref var node = ref Nodes[index]; - return node.ShouldCloneWhenLinked() - ? AddLeafNode(node.Type, node.Obj, node.NodeType, node.Kind, node.Flags, node.ChildIdx, node.ChildCount) - : index; + if (node.Is(ExprNodeKind.ParameterDeclaration)) + return AddParameterUsageNode(index); + if (node.HasNextLink) + return AddNodeReference(index); + + ReserveChildLinkForReuse(ref node); + return index; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void ReserveChildLinkForReuse(ref ExprNode node) + { + // Mark the node as "already linked" before wiring siblings/parent on the enclosing AddNode call. + // This ensures that repeated use of the same index in a single parent (e.g. Add(x, x)) + // keeps the first occurrence in-place and emits a reference node for later occurrences. + node.ReserveAsLinked(); } private ChildList CloneChildren(int[] children) @@ -1225,6 +1353,7 @@ public Reader(ExprTree tree) [RequiresUnreferencedCode(FastExpressionCompiler.LightExpression.Trimming.Message)] public SysExpr ReadExpression(int index) { + index = ResolveReferenceIndex(index); ref var node = ref _tree.Nodes[index]; if (!node.IsExpression()) throw new InvalidOperationException($"Node at index {index} is not an expression node."); @@ -1232,14 +1361,15 @@ public SysExpr ReadExpression(int index) switch (node.NodeType) { case ExpressionType.Constant: - return SysExpr.Constant(ReferenceEquals(node.Obj, ClosureConstantMarker) + return SysExpr.Constant(node.HasFlag(ConstantInClosureFlag) ? _tree.ClosureConstants[node.ChildIdx] : node.Obj, node.Type); case ExpressionType.Default: return SysExpr.Default(node.Type); case ExpressionType.Parameter: { - ref var parameter = ref _parametersById.Map.AddOrGetValueRef(node.ChildIdx, out var found); + var declarationIndex = node.Is(ExprNodeKind.ParameterUsage) ? node.ChildIdx : index; + ref var parameter = ref _parametersById.Map.AddOrGetValueRef(declarationIndex, out var found); if (found) return parameter; @@ -1454,6 +1584,7 @@ public SysExpr ReadExpression(int index) [RequiresUnreferencedCode(FastExpressionCompiler.LightExpression.Trimming.Message)] private SysSwitchCase ReadSwitchCase(int index) { + index = ResolveReferenceIndex(index); ref var node = ref _tree.Nodes[index]; Debug.Assert(node.Is(ExprNodeKind.SwitchCase)); var children = GetChildren(index); @@ -1466,6 +1597,7 @@ private SysSwitchCase ReadSwitchCase(int index) [RequiresUnreferencedCode(FastExpressionCompiler.LightExpression.Trimming.Message)] private SysCatchBlock ReadCatchBlock(int index) { + index = ResolveReferenceIndex(index); ref var node = ref _tree.Nodes[index]; Debug.Assert(node.Is(ExprNodeKind.CatchBlock)); var children = GetChildren(index); @@ -1478,6 +1610,7 @@ private SysCatchBlock ReadCatchBlock(int index) private SysLabelTarget ReadLabelTarget(int index) { + index = ResolveReferenceIndex(index); ref var node = ref _tree.Nodes[index]; Debug.Assert(node.Is(ExprNodeKind.LabelTarget)); ref var label = ref _labelsById.Map.AddOrGetValueRef(node.ChildIdx, out var found); @@ -1489,6 +1622,7 @@ private SysLabelTarget ReadLabelTarget(int index) private object ReadObjectReference(int index) { + index = ResolveReferenceIndex(index); ref var node = ref _tree.Nodes[index]; Debug.Assert(node.Is(ExprNodeKind.ObjectReference)); return node.Obj; @@ -1496,6 +1630,7 @@ private object ReadObjectReference(int index) private void ReadUInt16Pair(int index, out int first, out int second) { + index = ResolveReferenceIndex(index); ref var node = ref _tree.Nodes[index]; Debug.Assert(node.Is(ExprNodeKind.UInt16Pair)); first = node.ChildIdx; @@ -1505,6 +1640,7 @@ private void ReadUInt16Pair(int index, out int first, out int second) [RequiresUnreferencedCode(FastExpressionCompiler.LightExpression.Trimming.Message)] private SysMemberBinding ReadMemberBinding(int index) { + index = ResolveReferenceIndex(index); ref var node = ref _tree.Nodes[index]; var member = (System.Reflection.MemberInfo)node.Obj; switch (node.Kind) @@ -1535,11 +1671,20 @@ private SysMemberBinding ReadMemberBinding(int index) [RequiresUnreferencedCode(FastExpressionCompiler.LightExpression.Trimming.Message)] private SysElementInit ReadElementInit(int index) { + index = ResolveReferenceIndex(index); ref var node = ref _tree.Nodes[index]; Debug.Assert(node.Is(ExprNodeKind.ElementInit)); return SysExpr.ElementInit((System.Reflection.MethodInfo)node.Obj, ReadExpressions(GetChildren(index))); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private int ResolveReferenceIndex(int index) + { + while (_tree.Nodes[index].Is(ExprNodeKind.NodeReference)) + index = _tree.Nodes[index].ChildIdx; + return index; + } + private ChildList GetChildren(int index) { ref var node = ref _tree.Nodes[index]; diff --git a/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs b/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs index b91c23c0..134d4ddd 100644 --- a/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs +++ b/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs @@ -486,6 +486,55 @@ public void Can_build_flat_expression_control_flow_directly() Asserts.AreSame(gotoExpr.Target, label.Target); } + public void Flat_expression_splits_parameter_declarations_and_usages_with_scope_metadata() + { + var fe = default(ExprTree); + var p = fe.Parameter(typeof(int), "p"); + var body = fe.Add(p, p); + fe.RootIndex = fe.Lambda>(body, p); + + var declaration = fe.Nodes[p]; + Asserts.AreEqual(ExprNodeKind.ParameterDeclaration, declaration.Kind); + Asserts.AreEqual(fe.RootIndex, declaration.ChildIdx); + Asserts.AreEqual(0, declaration.ChildCount); + + var add = fe.Nodes[body]; + var firstUsageIndex = add.ChildIdx; + var secondUsageIndex = fe.Nodes[firstUsageIndex].NextIdx; + var firstUsage = fe.Nodes[firstUsageIndex]; + var secondUsage = fe.Nodes[secondUsageIndex]; + + Asserts.AreEqual(ExprNodeKind.ParameterUsage, firstUsage.Kind); + Asserts.AreEqual(ExprNodeKind.ParameterUsage, secondUsage.Kind); + Asserts.AreEqual(p, firstUsage.ChildIdx); + Asserts.AreEqual(p, secondUsage.ChildIdx); + Asserts.AreEqual(typeof(int), firstUsage.Type); + Asserts.AreEqual("p", (string)firstUsage.Obj); + } + + public void Flat_expression_links_last_child_to_parent_and_uses_reference_node_on_reuse() + { + var fe = default(ExprTree); + var shared = fe.ConstantInt(42); + var left = fe.Add(shared, fe.ConstantInt(1)); + var right = fe.Add(shared, fe.ConstantInt(2)); + var body = fe.Add(left, right); + fe.RootIndex = fe.Lambda>(body); + + var leftNode = fe.Nodes[left]; + var leftSecondChildIndex = fe.Nodes[leftNode.ChildIdx].NextIdx; + Asserts.AreEqual(left, fe.Nodes[leftSecondChildIndex].NextIdx); + + var rightNode = fe.Nodes[right]; + var rightFirstChild = fe.Nodes[rightNode.ChildIdx]; + Asserts.AreEqual(ExprNodeKind.NodeReference, rightFirstChild.Kind); + Asserts.AreEqual(shared, rightFirstChild.ChildIdx); + var rightSecondChildIndex = fe.Nodes[rightNode.ChildIdx].NextIdx; + Asserts.AreEqual(right, fe.Nodes[rightSecondChildIndex].NextIdx); + + Asserts.AreEqual(87, ((LambdaExpression)fe.ToLightExpression()).CompileFast>(true)()); + } + public class A { public P Prop { get; set; }