diff --git a/.autover/changes/c27a62e6-91ca-4a59-9406-394866cdfa62.json b/.autover/changes/c27a62e6-91ca-4a59-9406-394866cdfa62.json new file mode 100644 index 000000000..39be8933f --- /dev/null +++ b/.autover/changes/c27a62e6-91ca-4a59-9406-394866cdfa62.json @@ -0,0 +1,18 @@ +{ + "Projects": [ + { + "Name": "Amazon.Lambda.RuntimeSupport", + "Type": "Minor", + "ChangelogMessages": [ + "(Preview) Add response streaming support" + ] + }, + { + "Name": "Amazon.Lambda.Core", + "Type": "Minor", + "ChangelogMessages": [ + "(Preview) Add response streaming support" + ] + } + ] +} diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 000000000..62b77bced --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @aws/aws-sdk-dotnet-team diff --git a/.gitignore b/.gitignore index f91715274..1caae6fe4 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ *.suo *.user +**/.kiro/ + #################### # Build/Test folders #################### diff --git a/CHANGELOG.md b/CHANGELOG.md index 704ea1265..54300c8e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,23 @@ +## Release 2026-04-14 + +### Amazon.Lambda.TestTool.BlazorTester (0.17.1) +* Minor fixes to improve the testability of the package +### Amazon.Lambda.RuntimeSupport (1.14.3) +* Minor fixes to improve the testability of the package +### Amazon.Lambda.Annotations (1.13.0) +* Added [FunctionUrl] attribute for configuring Lambda functions with Function URL endpoints, including optional CORS support + +## Release 2026-04-13 #2 + +### Amazon.Lambda.Annotations (1.12.0) +* treat warnings as errors and fix unshipped.md +* Added [S3Event] annotation attribute for declaratively configuring S3 event-triggered Lambda functions with support for bucket reference, event types, key prefix/suffix filters, and enabled state. + +## Release 2026-04-08 + +### Amazon.Lambda.Annotations (1.11.0) +* Added [ALBApi] attribute for configuring Lambda functions as targets behind an Application Load Balancer + ## Release 2026-03-27 ### Amazon.Lambda.Annotations (1.10.0) diff --git a/Libraries/Amazon.Lambda.Annotations.slnf b/Libraries/Amazon.Lambda.Annotations.slnf index ecb4e01ee..d0bf67584 100644 --- a/Libraries/Amazon.Lambda.Annotations.slnf +++ b/Libraries/Amazon.Lambda.Annotations.slnf @@ -16,7 +16,10 @@ "test\\TestCustomAuthorizerApp.IntegrationTests\\TestCustomAuthorizerApp.IntegrationTests.csproj", "test\\TestServerlessApp.IntegrationTests\\TestServerlessApp.IntegrationTests.csproj", "test\\TestServerlessApp.NET8\\TestServerlessApp.NET8.csproj", - "test\\TestServerlessApp\\TestServerlessApp.csproj" + "src\\Amazon.Lambda.ApplicationLoadBalancerEvents\\Amazon.Lambda.ApplicationLoadBalancerEvents.csproj", + "test\\TestServerlessApp\\TestServerlessApp.csproj", + "test\\TestServerlessApp.ALB\\TestServerlessApp.ALB.csproj", + "test\\TestServerlessApp.ALB.IntegrationTests\\TestServerlessApp.ALB.IntegrationTests.csproj" ] } } diff --git a/Libraries/Libraries.sln b/Libraries/Libraries.sln index f3214606a..aa4c33d06 100644 --- a/Libraries/Libraries.sln +++ b/Libraries/Libraries.sln @@ -1,7 +1,7 @@  Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio Version 17 -VisualStudioVersion = 17.0.31717.71 +# Visual Studio Version 18 +VisualStudioVersion = 18.3.11512.155 d18.3 MinimumVisualStudioVersion = 10.0.40219.1 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{AAB54E74-20B1-42ED-BC3D-CE9F7BC7FD12}" EndProject @@ -151,6 +151,12 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TestCustomAuthorizerApp.Int EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TestCustomAuthorizerApp", "test\TestCustomAuthorizerApp\TestCustomAuthorizerApp.csproj", "{3BFA4B73-BA61-4578-833B-C5B3A16EDA9E}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TestServerlessApp.ALB", "test\TestServerlessApp.ALB\TestServerlessApp.ALB.csproj", "{8F7C617D-C611-4DC6-A07C-033F13C1835D}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TestServerlessApp.ALB.IntegrationTests", "test\TestServerlessApp.ALB.IntegrationTests\TestServerlessApp.ALB.IntegrationTests.csproj", "{80594C21-C6EB-469E-83CC-68F9F661CA5E}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ResponseStreamingFunctionHandlers", "test\Amazon.Lambda.RuntimeSupport.Tests\ResponseStreamingFunctionHandlers\ResponseStreamingFunctionHandlers.csproj", "{E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -941,6 +947,42 @@ Global {3BFA4B73-BA61-4578-833B-C5B3A16EDA9E}.Release|x64.Build.0 = Release|Any CPU {3BFA4B73-BA61-4578-833B-C5B3A16EDA9E}.Release|x86.ActiveCfg = Release|Any CPU {3BFA4B73-BA61-4578-833B-C5B3A16EDA9E}.Release|x86.Build.0 = Release|Any CPU + {8F7C617D-C611-4DC6-A07C-033F13C1835D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {8F7C617D-C611-4DC6-A07C-033F13C1835D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {8F7C617D-C611-4DC6-A07C-033F13C1835D}.Debug|x64.ActiveCfg = Debug|Any CPU + {8F7C617D-C611-4DC6-A07C-033F13C1835D}.Debug|x64.Build.0 = Debug|Any CPU + {8F7C617D-C611-4DC6-A07C-033F13C1835D}.Debug|x86.ActiveCfg = Debug|Any CPU + {8F7C617D-C611-4DC6-A07C-033F13C1835D}.Debug|x86.Build.0 = Debug|Any CPU + {8F7C617D-C611-4DC6-A07C-033F13C1835D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {8F7C617D-C611-4DC6-A07C-033F13C1835D}.Release|Any CPU.Build.0 = Release|Any CPU + {8F7C617D-C611-4DC6-A07C-033F13C1835D}.Release|x64.ActiveCfg = Release|Any CPU + {8F7C617D-C611-4DC6-A07C-033F13C1835D}.Release|x64.Build.0 = Release|Any CPU + {8F7C617D-C611-4DC6-A07C-033F13C1835D}.Release|x86.ActiveCfg = Release|Any CPU + {8F7C617D-C611-4DC6-A07C-033F13C1835D}.Release|x86.Build.0 = Release|Any CPU + {80594C21-C6EB-469E-83CC-68F9F661CA5E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {80594C21-C6EB-469E-83CC-68F9F661CA5E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {80594C21-C6EB-469E-83CC-68F9F661CA5E}.Debug|x64.ActiveCfg = Debug|Any CPU + {80594C21-C6EB-469E-83CC-68F9F661CA5E}.Debug|x64.Build.0 = Debug|Any CPU + {80594C21-C6EB-469E-83CC-68F9F661CA5E}.Debug|x86.ActiveCfg = Debug|Any CPU + {80594C21-C6EB-469E-83CC-68F9F661CA5E}.Debug|x86.Build.0 = Debug|Any CPU + {80594C21-C6EB-469E-83CC-68F9F661CA5E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {80594C21-C6EB-469E-83CC-68F9F661CA5E}.Release|Any CPU.Build.0 = Release|Any CPU + {80594C21-C6EB-469E-83CC-68F9F661CA5E}.Release|x64.ActiveCfg = Release|Any CPU + {80594C21-C6EB-469E-83CC-68F9F661CA5E}.Release|x64.Build.0 = Release|Any CPU + {80594C21-C6EB-469E-83CC-68F9F661CA5E}.Release|x86.ActiveCfg = Release|Any CPU + {80594C21-C6EB-469E-83CC-68F9F661CA5E}.Release|x86.Build.0 = Release|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Debug|x64.ActiveCfg = Debug|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Debug|x64.Build.0 = Debug|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Debug|x86.ActiveCfg = Debug|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Debug|x86.Build.0 = Debug|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Release|Any CPU.Build.0 = Release|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Release|x64.ActiveCfg = Release|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Release|x64.Build.0 = Release|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Release|x86.ActiveCfg = Release|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1015,6 +1057,9 @@ Global {8D03BDF3-7078-4B46-A3F1-C73BE6D6CE0D} = {1DE4EE60-45BA-4EF7-BE00-B9EB861E4C69} {8EEDD576-7FC4-4FAC-A5A2-F58562753A53} = {1DE4EE60-45BA-4EF7-BE00-B9EB861E4C69} {3BFA4B73-BA61-4578-833B-C5B3A16EDA9E} = {1DE4EE60-45BA-4EF7-BE00-B9EB861E4C69} + {8F7C617D-C611-4DC6-A07C-033F13C1835D} = {1DE4EE60-45BA-4EF7-BE00-B9EB861E4C69} + {80594C21-C6EB-469E-83CC-68F9F661CA5E} = {1DE4EE60-45BA-4EF7-BE00-B9EB861E4C69} + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9} = {B5BD0336-7D08-492C-8489-42C987E29B39} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {503678A4-B8D1-4486-8915-405A3E9CF0EB} diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Amazon.Lambda.Annotations.SourceGenerator.csproj b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Amazon.Lambda.Annotations.SourceGenerator.csproj index 3e2dd821e..79a18d2b8 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Amazon.Lambda.Annotations.SourceGenerator.csproj +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Amazon.Lambda.Annotations.SourceGenerator.csproj @@ -20,7 +20,8 @@ true false - 1.10.0 + 1.13.0 + true diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Diagnostics/AnalyzerReleases.Unshipped.md b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Diagnostics/AnalyzerReleases.Unshipped.md index e9b44dd1e..d1a9a89d0 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Diagnostics/AnalyzerReleases.Unshipped.md +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Diagnostics/AnalyzerReleases.Unshipped.md @@ -16,3 +16,8 @@ AWSLambda0128 | AWSLambdaCSharpGenerator | Error | Authorizer Payload Version Mi AWSLambda0129 | AWSLambdaCSharpGenerator | Error | Missing LambdaFunction Attribute AWSLambda0130 | AWSLambdaCSharpGenerator | Error | Invalid return type IAuthorizerResult AWSLambda0131 | AWSLambdaCSharpGenerator | Error | FromBody not supported on Authorizer functions +AWSLambda0132 | AWSLambdaCSharpGenerator | Error | Invalid ALBApiAttribute +AWSLambda0133 | AWSLambdaCSharpGenerator | Error | ALB Listener Reference Not Found +AWSLambda0134 | AWSLambdaCSharpGenerator | Error | FromRoute not supported on ALB functions +AWSLambda0135 | AWSLambdaCSharpGenerator | Error | Unmapped parameter on ALB function +AWSLambda0136 | AWSLambdaCSharpGenerator | Error | Invalid S3EventAttribute diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Diagnostics/DiagnosticDescriptors.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Diagnostics/DiagnosticDescriptors.cs index 69c4f9428..e1a11087f 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Diagnostics/DiagnosticDescriptors.cs +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Diagnostics/DiagnosticDescriptors.cs @@ -242,5 +242,44 @@ public static class DiagnosticDescriptors category: "AWSLambdaCSharpGenerator", DiagnosticSeverity.Error, isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor InvalidAlbApiAttribute = new DiagnosticDescriptor( + id: "AWSLambda0132", + title: "Invalid ALBApiAttribute", + messageFormat: "Invalid ALBApiAttribute encountered: {0}", + category: "AWSLambdaCSharpGenerator", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor AlbListenerReferenceNotFound = new DiagnosticDescriptor( + id: "AWSLambda0133", + title: "ALB Listener Reference Not Found", + messageFormat: "The ALBApi ListenerArn references '@{0}', but no resource or parameter named '{0}' was found in the CloudFormation template. Add the listener resource to the template or correct the reference name.", + category: "AWSLambdaCSharpGenerator", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor FromRouteNotSupportedOnAlb = new DiagnosticDescriptor( + id: "AWSLambda0134", + title: "FromRoute not supported on ALB functions", + messageFormat: "[FromRoute] is not supported on ALB functions. ALB does not support route path template parameters. Use [FromHeader], [FromQuery], or [FromBody] instead.", + category: "AWSLambdaCSharpGenerator", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor AlbUnmappedParameter = new DiagnosticDescriptor( + id: "AWSLambda0135", + title: "Unmapped parameter on ALB function", + messageFormat: "Parameter '{0}' on ALB function has no binding attribute. Use [FromHeader], [FromQuery], [FromBody], or [FromServices], or use the ApplicationLoadBalancerRequest or ILambdaContext types.", + category: "AWSLambdaCSharpGenerator", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor InvalidS3EventAttribute = new DiagnosticDescriptor(id: "AWSLambda0136", + title: "Invalid S3EventAttribute", + messageFormat: "Invalid S3EventAttribute encountered: {0}", + category: "AWSLambdaCSharpGenerator", + DiagnosticSeverity.Error, + isEnabledByDefault: true); } } diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Extensions/ParameterListExtension.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Extensions/ParameterListExtension.cs index 5465f8323..9310019eb 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Extensions/ParameterListExtension.cs +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Extensions/ParameterListExtension.cs @@ -17,6 +17,12 @@ public static bool HasConvertibleParameter(this IList parameters return false; } + // ALB request types are forwarded to lambda method if specified, there is no parameter conversion required. + if (TypeFullNames.ALBRequests.Contains(p.Type.FullName)) + { + return false; + } + // ILambdaContext is forwarded to lambda method if specified, there is no parameter conversion required. if (p.Type.FullName == TypeFullNames.ILambdaContext) { @@ -24,7 +30,7 @@ public static bool HasConvertibleParameter(this IList parameters } // Body parameter with target type as string doesn't require conversion because body is string by nature. - if (p.Attributes.Any(att => att.Type.FullName == TypeFullNames.FromBodyAttribute) && p.Type.IsString()) + if (p.Attributes.Any(att => att.Type.FullName == TypeFullNames.FromBodyAttribute || att.Type.FullName == TypeFullNames.ALBFromBodyAttribute) && p.Type.IsString()) { return false; } diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/ALBApiAttributeBuilder.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/ALBApiAttributeBuilder.cs new file mode 100644 index 000000000..d64f64048 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/ALBApiAttributeBuilder.cs @@ -0,0 +1,68 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Annotations.ALB; +using Microsoft.CodeAnalysis; +using System; +using System.Linq; + +namespace Amazon.Lambda.Annotations.SourceGenerator.Models.Attributes +{ + /// + /// Builder for . + /// + public class ALBApiAttributeBuilder + { + public static ALBApiAttribute Build(AttributeData att) + { + if (att.ConstructorArguments.Length != 3) + { + throw new NotSupportedException($"{TypeFullNames.ALBApiAttribute} must have constructor with 3 arguments."); + } + + var listenerArn = att.ConstructorArguments[0].Value as string; + var pathPattern = att.ConstructorArguments[1].Value as string; + var priority = (int)att.ConstructorArguments[2].Value; + + var data = new ALBApiAttribute(listenerArn, pathPattern, priority); + + foreach (var pair in att.NamedArguments) + { + if (pair.Key == nameof(data.MultiValueHeaders) && pair.Value.Value is bool multiValueHeaders) + { + data.MultiValueHeaders = multiValueHeaders; + } + else if (pair.Key == nameof(data.HostHeader) && pair.Value.Value is string hostHeader) + { + data.HostHeader = hostHeader; + } + else if (pair.Key == nameof(data.HttpMethod) && pair.Value.Value is string httpMethod) + { + data.HttpMethod = httpMethod; + } + else if (pair.Key == nameof(data.ResourceName) && pair.Value.Value is string resourceName) + { + data.ResourceName = resourceName; + } + else if (pair.Key == nameof(data.HttpHeaderConditionName) && pair.Value.Value is string httpHeaderConditionName) + { + data.HttpHeaderConditionName = httpHeaderConditionName; + } + else if (pair.Key == nameof(data.HttpHeaderConditionValues) && !pair.Value.IsNull) + { + data.HttpHeaderConditionValues = pair.Value.Values.Select(v => v.Value as string).ToArray(); + } + else if (pair.Key == nameof(data.QueryStringConditions) && !pair.Value.IsNull) + { + data.QueryStringConditions = pair.Value.Values.Select(v => v.Value as string).ToArray(); + } + else if (pair.Key == nameof(data.SourceIpConditions) && !pair.Value.IsNull) + { + data.SourceIpConditions = pair.Value.Values.Select(v => v.Value as string).ToArray(); + } + } + + return data; + } + } +} diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/ALBFromHeaderAttributeBuilder.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/ALBFromHeaderAttributeBuilder.cs new file mode 100644 index 000000000..a0ca9aced --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/ALBFromHeaderAttributeBuilder.cs @@ -0,0 +1,28 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Annotations.ALB; +using Microsoft.CodeAnalysis; + +namespace Amazon.Lambda.Annotations.SourceGenerator.Models.Attributes +{ + /// + /// Builder for . + /// + public class ALBFromHeaderAttributeBuilder + { + public static ALB.FromHeaderAttribute Build(AttributeData att) + { + var data = new ALB.FromHeaderAttribute(); + foreach (var pair in att.NamedArguments) + { + if (pair.Key == nameof(data.Name) && pair.Value.Value is string value) + { + data.Name = value; + } + } + + return data; + } + } +} diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/ALBFromQueryAttributeBuilder.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/ALBFromQueryAttributeBuilder.cs new file mode 100644 index 000000000..8fb7ce644 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/ALBFromQueryAttributeBuilder.cs @@ -0,0 +1,28 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Annotations.ALB; +using Microsoft.CodeAnalysis; + +namespace Amazon.Lambda.Annotations.SourceGenerator.Models.Attributes +{ + /// + /// Builder for . + /// + public class ALBFromQueryAttributeBuilder + { + public static ALB.FromQueryAttribute Build(AttributeData att) + { + var data = new ALB.FromQueryAttribute(); + foreach (var pair in att.NamedArguments) + { + if (pair.Key == nameof(data.Name) && pair.Value.Value is string value) + { + data.Name = value; + } + } + + return data; + } + } +} diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/AttributeModelBuilder.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/AttributeModelBuilder.cs index 328a29ac5..d8715c047 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/AttributeModelBuilder.cs +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/AttributeModelBuilder.cs @@ -1,5 +1,10 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + using System; +using Amazon.Lambda.Annotations.ALB; using Amazon.Lambda.Annotations.APIGateway; +using Amazon.Lambda.Annotations.S3; using Amazon.Lambda.Annotations.SQS; using Microsoft.CodeAnalysis; @@ -30,7 +35,7 @@ public static AttributeModel Build(AttributeData att, GeneratorExecutionContext else if (att.AttributeClass.Equals(context.Compilation.GetTypeByMetadataName(TypeFullNames.FromQueryAttribute), SymbolEqualityComparer.Default)) { var data = FromQueryAttributeBuilder.Build(att); - model = new AttributeModel + model = new AttributeModel { Data = data, Type = TypeModelBuilder.Build(att.AttributeClass, context) @@ -39,7 +44,7 @@ public static AttributeModel Build(AttributeData att, GeneratorExecutionContext else if (att.AttributeClass.Equals(context.Compilation.GetTypeByMetadataName(TypeFullNames.FromHeaderAttribute), SymbolEqualityComparer.Default)) { var data = FromHeaderAttributeBuilder.Build(att); - model = new AttributeModel + model = new AttributeModel { Data = data, Type = TypeModelBuilder.Build(att.AttributeClass, context) @@ -90,6 +95,24 @@ public static AttributeModel Build(AttributeData att, GeneratorExecutionContext Type = TypeModelBuilder.Build(att.AttributeClass, context) }; } + else if (att.AttributeClass.Equals(context.Compilation.GetTypeByMetadataName(TypeFullNames.S3EventAttribute), SymbolEqualityComparer.Default)) + { + var data = S3EventAttributeBuilder.Build(att); + model = new AttributeModel + { + Data = data, + Type = TypeModelBuilder.Build(att.AttributeClass, context) + }; + } + else if (att.AttributeClass.Equals(context.Compilation.GetTypeByMetadataName(TypeFullNames.FunctionUrlAttribute), SymbolEqualityComparer.Default)) + { + var data = FunctionUrlAttributeBuilder.Build(att); + model = new AttributeModel + { + Data = data, + Type = TypeModelBuilder.Build(att.AttributeClass, context) + }; + } else if (att.AttributeClass.Equals(context.Compilation.GetTypeByMetadataName(TypeFullNames.HttpApiAuthorizerAttribute), SymbolEqualityComparer.Default)) { var data = HttpApiAuthorizerAttributeBuilder.Build(att); @@ -108,6 +131,42 @@ public static AttributeModel Build(AttributeData att, GeneratorExecutionContext Type = TypeModelBuilder.Build(att.AttributeClass, context) }; } + else if (att.AttributeClass.Equals(context.Compilation.GetTypeByMetadataName(TypeFullNames.ALBApiAttribute), SymbolEqualityComparer.Default)) + { + var data = ALBApiAttributeBuilder.Build(att); + model = new AttributeModel + { + Data = data, + Type = TypeModelBuilder.Build(att.AttributeClass, context) + }; + } + else if (att.AttributeClass.Equals(context.Compilation.GetTypeByMetadataName(TypeFullNames.ALBFromQueryAttribute), SymbolEqualityComparer.Default)) + { + var data = ALBFromQueryAttributeBuilder.Build(att); + model = new AttributeModel + { + Data = data, + Type = TypeModelBuilder.Build(att.AttributeClass, context) + }; + } + else if (att.AttributeClass.Equals(context.Compilation.GetTypeByMetadataName(TypeFullNames.ALBFromHeaderAttribute), SymbolEqualityComparer.Default)) + { + var data = ALBFromHeaderAttributeBuilder.Build(att); + model = new AttributeModel + { + Data = data, + Type = TypeModelBuilder.Build(att.AttributeClass, context) + }; + } + else if (att.AttributeClass.Equals(context.Compilation.GetTypeByMetadataName(TypeFullNames.ALBFromBodyAttribute), SymbolEqualityComparer.Default)) + { + var data = new ALB.FromBodyAttribute(); + model = new AttributeModel + { + Data = data, + Type = TypeModelBuilder.Build(att.AttributeClass, context) + }; + } else { model = new AttributeModel @@ -119,4 +178,4 @@ public static AttributeModel Build(AttributeData att, GeneratorExecutionContext return model; } } -} \ No newline at end of file +} diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/FunctionUrlAttributeBuilder.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/FunctionUrlAttributeBuilder.cs new file mode 100644 index 000000000..48bb69ea8 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/FunctionUrlAttributeBuilder.cs @@ -0,0 +1,48 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System.Linq; +using Amazon.Lambda.Annotations.APIGateway; +using Microsoft.CodeAnalysis; + +namespace Amazon.Lambda.Annotations.SourceGenerator.Models.Attributes +{ + public static class FunctionUrlAttributeBuilder + { + public static FunctionUrlAttribute Build(AttributeData att) + { + var authType = att.NamedArguments.FirstOrDefault(arg => arg.Key == "AuthType").Value.Value; + + var data = new FunctionUrlAttribute + { + AuthType = authType == null ? FunctionUrlAuthType.NONE : (FunctionUrlAuthType)authType + }; + + var allowOrigins = att.NamedArguments.FirstOrDefault(arg => arg.Key == "AllowOrigins").Value; + if (!allowOrigins.IsNull) + data.AllowOrigins = allowOrigins.Values.Select(v => v.Value as string).ToArray(); + + var allowMethods = att.NamedArguments.FirstOrDefault(arg => arg.Key == "AllowMethods").Value; + if (!allowMethods.IsNull) + data.AllowMethods = allowMethods.Values.Select(v => (LambdaHttpMethod)(int)v.Value).ToArray(); + + var allowHeaders = att.NamedArguments.FirstOrDefault(arg => arg.Key == "AllowHeaders").Value; + if (!allowHeaders.IsNull) + data.AllowHeaders = allowHeaders.Values.Select(v => v.Value as string).ToArray(); + + var exposeHeaders = att.NamedArguments.FirstOrDefault(arg => arg.Key == "ExposeHeaders").Value; + if (!exposeHeaders.IsNull) + data.ExposeHeaders = exposeHeaders.Values.Select(v => v.Value as string).ToArray(); + + var allowCredentials = att.NamedArguments.FirstOrDefault(arg => arg.Key == "AllowCredentials").Value.Value; + if (allowCredentials != null) + data.AllowCredentials = (bool)allowCredentials; + + var maxAge = att.NamedArguments.FirstOrDefault(arg => arg.Key == "MaxAge").Value.Value; + if (maxAge != null) + data.MaxAge = (int)maxAge; + + return data; + } + } +} diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/S3EventAttributeBuilder.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/S3EventAttributeBuilder.cs new file mode 100644 index 000000000..66070af74 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/Attributes/S3EventAttributeBuilder.cs @@ -0,0 +1,37 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Annotations.S3; +using Microsoft.CodeAnalysis; +using System; + +namespace Amazon.Lambda.Annotations.SourceGenerator.Models.Attributes +{ + public class S3EventAttributeBuilder + { + public static S3EventAttribute Build(AttributeData att) + { + if (att.ConstructorArguments.Length != 1) + throw new NotSupportedException($"{TypeFullNames.S3EventAttribute} must have constructor with 1 argument."); + + var bucket = att.ConstructorArguments[0].Value as string; + var data = new S3EventAttribute(bucket); + + foreach (var pair in att.NamedArguments) + { + if (pair.Key == nameof(data.ResourceName) && pair.Value.Value is string resourceName) + data.ResourceName = resourceName; + else if (pair.Key == nameof(data.Events) && pair.Value.Value is string events) + data.Events = events; + else if (pair.Key == nameof(data.FilterPrefix) && pair.Value.Value is string filterPrefix) + data.FilterPrefix = filterPrefix; + else if (pair.Key == nameof(data.FilterSuffix) && pair.Value.Value is string filterSuffix) + data.FilterSuffix = filterSuffix; + else if (pair.Key == nameof(data.Enabled) && pair.Value.Value is bool enabled) + data.Enabled = enabled; + } + + return data; + } + } +} diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/EventType.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/EventType.cs index d231967e3..1b392572d 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/EventType.cs +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/EventType.cs @@ -11,6 +11,7 @@ public enum EventType SQS, DynamoDB, Schedule, - Authorizer + Authorizer, + ALB } } \ No newline at end of file diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/EventTypeBuilder.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/EventTypeBuilder.cs index 3f5775851..d3c1f7fd0 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/EventTypeBuilder.cs +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/EventTypeBuilder.cs @@ -1,3 +1,6 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + using System; using System.Collections.Generic; using System.Linq; @@ -18,7 +21,8 @@ public static HashSet Build(IMethodSymbol lambdaMethodSymbol, foreach (var attribute in lambdaMethodSymbol.GetAttributes()) { if (attribute.AttributeClass.ToDisplayString() == TypeFullNames.RestApiAttribute - || attribute.AttributeClass.ToDisplayString() == TypeFullNames.HttpApiAttribute) + || attribute.AttributeClass.ToDisplayString() == TypeFullNames.HttpApiAttribute + || attribute.AttributeClass.ToDisplayString() == TypeFullNames.FunctionUrlAttribute) { events.Add(EventType.API); } @@ -26,11 +30,19 @@ public static HashSet Build(IMethodSymbol lambdaMethodSymbol, { events.Add(EventType.SQS); } + else if (attribute.AttributeClass.ToDisplayString() == TypeFullNames.S3EventAttribute) + { + events.Add(EventType.S3); + } else if (attribute.AttributeClass.ToDisplayString() == TypeFullNames.HttpApiAuthorizerAttribute || attribute.AttributeClass.ToDisplayString() == TypeFullNames.RestApiAuthorizerAttribute) { events.Add(EventType.Authorizer); } + else if (attribute.AttributeClass.ToDisplayString() == TypeFullNames.ALBApiAttribute) + { + events.Add(EventType.ALB); + } } return events; diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/GeneratedMethodModelBuilder.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/GeneratedMethodModelBuilder.cs index decb864ee..e3c6a020e 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/GeneratedMethodModelBuilder.cs +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/GeneratedMethodModelBuilder.cs @@ -1,3 +1,6 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + using System; using System.Collections.Generic; using System.Linq; @@ -130,6 +133,28 @@ private static TypeModel BuildResponseType(IMethodSymbol lambdaMethodSymbol, throw new ArgumentOutOfRangeException(); } } + else if (lambdaMethodSymbol.HasAttribute(context, TypeFullNames.ALBApiAttribute)) + { + // ALB functions return ApplicationLoadBalancerResponse + // If the user already returns ApplicationLoadBalancerResponse, pass through the return type. + // Otherwise, wrap in ApplicationLoadBalancerResponse. + if (lambdaMethodModel.ReturnsApplicationLoadBalancerResponse) + { + return lambdaMethodModel.ReturnType; + } + var symbol = lambdaMethodModel.ReturnsVoidOrGenericTask ? + task.Construct(context.Compilation.GetTypeByMetadataName(TypeFullNames.ApplicationLoadBalancerResponse)): + context.Compilation.GetTypeByMetadataName(TypeFullNames.ApplicationLoadBalancerResponse); + return TypeModelBuilder.Build(symbol, context); + } + else if (lambdaMethodSymbol.HasAttribute(context, TypeFullNames.FunctionUrlAttribute)) + { + // Function URLs use the same payload format as HTTP API v2 + var symbol = lambdaMethodModel.ReturnsVoidOrGenericTask ? + task.Construct(context.Compilation.GetTypeByMetadataName(TypeFullNames.APIGatewayHttpApiV2ProxyResponse)): + context.Compilation.GetTypeByMetadataName(TypeFullNames.APIGatewayHttpApiV2ProxyResponse); + return TypeModelBuilder.Build(symbol, context); + } else { return lambdaMethodModel.ReturnType; @@ -277,6 +302,33 @@ private static IList BuildParameters(IMethodSymbol lambdaMethodS parameters.Add(requestParameter); parameters.Add(contextParameter); } + else if (lambdaMethodSymbol.HasAttribute(context, TypeFullNames.ALBApiAttribute)) + { + var symbol = context.Compilation.GetTypeByMetadataName(TypeFullNames.ApplicationLoadBalancerRequest); + var type = TypeModelBuilder.Build(symbol, context); + var requestParameter = new ParameterModel + { + Name = "__request__", + Type = type, + Documentation = "The ALB request object that will be processed by the Lambda function handler." + }; + parameters.Add(requestParameter); + parameters.Add(contextParameter); + } + else if (lambdaMethodSymbol.HasAttribute(context, TypeFullNames.FunctionUrlAttribute)) + { + // Function URLs use the same payload format as HTTP API v2 + var symbol = context.Compilation.GetTypeByMetadataName(TypeFullNames.APIGatewayHttpApiV2ProxyRequest); + var type = TypeModelBuilder.Build(symbol, context); + var requestParameter = new ParameterModel + { + Name = "__request__", + Type = type, + Documentation = "The Function URL request object that will be processed by the Lambda function handler." + }; + parameters.Add(requestParameter); + parameters.Add(contextParameter); + } else { // Lambda method with no event attribute are plain lambda functions, therefore, generated method will have diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/LambdaMethodModel.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/LambdaMethodModel.cs index df80c43e5..601e4d86e 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/LambdaMethodModel.cs +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Models/LambdaMethodModel.cs @@ -89,6 +89,31 @@ public bool ReturnsIAuthorizerResult } } + /// + /// Returns true if the Lambda function returns either ApplicationLoadBalancerResponse or Task<ApplicationLoadBalancerResponse> + /// + public bool ReturnsApplicationLoadBalancerResponse + { + get + { + if (ReturnsVoid) + { + return false; + } + + if (ReturnType.FullName == TypeFullNames.ApplicationLoadBalancerResponse) + { + return true; + } + if (ReturnsGenericTask && ReturnType.TypeArguments.Count == 1 && ReturnType.TypeArguments[0].FullName == TypeFullNames.ApplicationLoadBalancerResponse) + { + return true; + } + + return false; + } + } + /// /// Returns true if the Lambda function returns either void, Task, SQSBatchResponse or Task /// diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/SyntaxReceiver.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/SyntaxReceiver.cs index a5d7ce9ab..230525edd 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/SyntaxReceiver.cs +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/SyntaxReceiver.cs @@ -1,4 +1,7 @@ -using System; +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System; using System.Collections.Generic; using System.Linq; using Amazon.Lambda.Annotations.SourceGenerator.FileIO; @@ -21,7 +24,10 @@ internal class SyntaxReceiver : ISyntaxContextReceiver { "RestApiAuthorizerAttribute", "RestApiAuthorizer" }, { "HttpApiAttribute", "HttpApi" }, { "RestApiAttribute", "RestApi" }, - { "SQSEventAttribute", "SQSEvent" } + { "FunctionUrlAttribute", "FunctionUrl" }, + { "SQSEventAttribute", "SQSEvent" }, + { "ALBApiAttribute", "ALBApi" }, + { "S3EventAttribute", "S3Event" } }; public List LambdaMethods { get; } = new List(); @@ -120,4 +126,4 @@ public void OnVisitSyntaxNode(GeneratorSyntaxContext context) } } } -} \ No newline at end of file +} diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBInvoke.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBInvoke.cs new file mode 100644 index 000000000..a09bec840 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBInvoke.cs @@ -0,0 +1,420 @@ +// ------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +// ------------------------------------------------------------------------------ +namespace Amazon.Lambda.Annotations.SourceGenerator.Templates +{ + using System.Linq; + using System.Text; + using System.Collections.Generic; + using Amazon.Lambda.Annotations.SourceGenerator.Extensions; + using Amazon.Lambda.Annotations.SourceGenerator.Models; + using Amazon.Lambda.Annotations.SourceGenerator.Models.Attributes; + using System; + + /// + /// Class to produce the template output + /// + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "18.0.0.0")] + public partial class ALBInvoke : ALBInvokeBase + { + /// + /// Create the template output + /// + public virtual string TransformText() + { + + if (_model.GeneratedMethod.ReturnType.FullName == _model.LambdaMethod.ReturnType.FullName) + { + // User already returns ApplicationLoadBalancerResponse (or Task), + // just pass through. + if (_model.LambdaMethod.ReturnsVoid) + { + + this.Write(" "); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.ContainingType.Name.ToCamelCase())); + this.Write("."); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.Name)); + this.Write("("); + this.Write(this.ToStringHelper.ToStringWithCulture(_parameterSignature)); + this.Write(");\r\n"); + + } + else if (_model.LambdaMethod.ReturnsVoidTask) + { + + this.Write(" await "); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.ContainingType.Name.ToCamelCase())); + this.Write("."); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.Name)); + this.Write("("); + this.Write(this.ToStringHelper.ToStringWithCulture(_parameterSignature)); + this.Write(");\r\n"); + + } + else + { + + this.Write(" var response = "); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.ReturnsGenericTask ? "await " : "")); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.ContainingType.Name.ToCamelCase())); + this.Write("."); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.Name)); + this.Write("("); + this.Write(this.ToStringHelper.ToStringWithCulture(_parameterSignature)); + this.Write(");\r\n return response;\r\n"); + + } + } + else + { + // User returns a non-ALB type, we need to wrap in ApplicationLoadBalancerResponse + if (_model.LambdaMethod.ReturnsVoid) + { + + this.Write(" "); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.ContainingType.Name.ToCamelCase())); + this.Write("."); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.Name)); + this.Write("("); + this.Write(this.ToStringHelper.ToStringWithCulture(_parameterSignature)); + this.Write(");\r\n"); + + } + else if (_model.LambdaMethod.ReturnsVoidTask) + { + + this.Write(" await "); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.ContainingType.Name.ToCamelCase())); + this.Write("."); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.Name)); + this.Write("("); + this.Write(this.ToStringHelper.ToStringWithCulture(_parameterSignature)); + this.Write(");\r\n"); + + } + else + { + + this.Write(" var response = "); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.ReturnsGenericTask ? "await " : "")); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.ContainingType.Name.ToCamelCase())); + this.Write("."); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.Name)); + this.Write("("); + this.Write(this.ToStringHelper.ToStringWithCulture(_parameterSignature)); + this.Write(");\r\n"); + + if (_model.LambdaMethod.ReturnType.IsValueType) + { + + this.Write("\r\n var body = response.ToString();\r\n"); + + } + else if (_model.LambdaMethod.ReturnType.IsString()) + { + // no action needed, response is already a string + } + else + { + + this.Write(" var memoryStream = new MemoryStream();\r\n" + + " serializer.Serialize(response, memoryStream);\r\n" + + " memoryStream.Position = 0;\r\n\r\n" + + " // convert stream to string\r\n" + + " StreamReader reader = new StreamReader( memoryStream );\r\n" + + " var body = reader.ReadToEnd();\r\n"); + + } + } + + this.Write("\r\n return new Amazon.Lambda.ApplicationLoadBalancerEvents.ApplicationLoadBalancerResponse\r\n {\r\n"); + + if (!_model.LambdaMethod.ReturnsVoid && !_model.LambdaMethod.ReturnsVoidTask) + { + + this.Write(" Body = "); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.ReturnType.IsString() ? "response" : "body")); + this.Write(",\r\n Headers = new Dictionary\r\n {\r\n {\"Content-Type\", "); + this.Write(this.ToStringHelper.ToStringWithCulture(_model.LambdaMethod.ReturnType.IsString() ? "\"text/plain\"" : "\"application/json\"")); + this.Write("}\r\n },\r\n"); + + } + + this.Write(" StatusCode = 200\r\n };\r\n"); + + } + + return this.GenerationEnvironment.ToString(); + } + } + + #region Base class + /// + /// Base class for this transformation + /// + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "18.0.0.0")] + public class ALBInvokeBase + { + #region Fields + private global::System.Text.StringBuilder generationEnvironmentField; + private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField; + private global::System.Collections.Generic.List indentLengthsField; + private string currentIndentField = ""; + private bool endsWithNewline; + private global::System.Collections.Generic.IDictionary sessionField; + #endregion + #region Properties + /// + /// The string builder that generation-time code is using to assemble generated output + /// + public System.Text.StringBuilder GenerationEnvironment + { + get + { + if ((this.generationEnvironmentField == null)) + { + this.generationEnvironmentField = new global::System.Text.StringBuilder(); + } + return this.generationEnvironmentField; + } + set + { + this.generationEnvironmentField = value; + } + } + /// + /// The error collection for the generation process + /// + public System.CodeDom.Compiler.CompilerErrorCollection Errors + { + get + { + if ((this.errorsField == null)) + { + this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection(); + } + return this.errorsField; + } + } + /// + /// A list of the lengths of each indent that was added with PushIndent + /// + private System.Collections.Generic.List indentLengths + { + get + { + if ((this.indentLengthsField == null)) + { + this.indentLengthsField = new global::System.Collections.Generic.List(); + } + return this.indentLengthsField; + } + } + /// + /// Gets the current indent we use when adding lines to the output + /// + public string CurrentIndent + { + get + { + return this.currentIndentField; + } + } + /// + /// Current transformation session + /// + public virtual global::System.Collections.Generic.IDictionary Session + { + get + { + return this.sessionField; + } + set + { + this.sessionField = value; + } + } + #endregion + #region Transform-time helpers + /// + /// Write text directly into the generated output + /// + public void Write(string textToAppend) + { + if (string.IsNullOrEmpty(textToAppend)) + { + return; + } + if (((this.GenerationEnvironment.Length == 0) + || this.endsWithNewline)) + { + this.GenerationEnvironment.Append(this.currentIndentField); + this.endsWithNewline = false; + } + if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture)) + { + this.endsWithNewline = true; + } + if ((this.currentIndentField.Length == 0)) + { + this.GenerationEnvironment.Append(textToAppend); + return; + } + textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField)); + if (this.endsWithNewline) + { + this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length)); + } + else + { + this.GenerationEnvironment.Append(textToAppend); + } + } + /// + /// Write text directly into the generated output + /// + public void WriteLine(string textToAppend) + { + this.Write(textToAppend); + this.GenerationEnvironment.AppendLine(); + this.endsWithNewline = true; + } + /// + /// Write formatted text directly into the generated output + /// + public void Write(string format, params object[] args) + { + this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Write formatted text directly into the generated output + /// + public void WriteLine(string format, params object[] args) + { + this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Raise an error + /// + public void Error(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + this.Errors.Add(error); + } + /// + /// Raise a warning + /// + public void Warning(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + error.IsWarning = true; + this.Errors.Add(error); + } + /// + /// Increase the indent + /// + public void PushIndent(string indent) + { + if ((indent == null)) + { + throw new global::System.ArgumentNullException("indent"); + } + this.currentIndentField = (this.currentIndentField + indent); + this.indentLengths.Add(indent.Length); + } + /// + /// Remove the last indent that was added with PushIndent + /// + public string PopIndent() + { + string returnValue = ""; + if ((this.indentLengths.Count > 0)) + { + int indentLength = this.indentLengths[(this.indentLengths.Count - 1)]; + this.indentLengths.RemoveAt((this.indentLengths.Count - 1)); + if ((indentLength > 0)) + { + returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength)); + this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength)); + } + } + return returnValue; + } + /// + /// Remove any indentation + /// + public void ClearIndent() + { + this.indentLengths.Clear(); + this.currentIndentField = ""; + } + #endregion + #region ToString Helpers + /// + /// Utility class to produce culture-oriented representation of an object as a string. + /// + public class ToStringInstanceHelper + { + private System.IFormatProvider formatProviderField = global::System.Globalization.CultureInfo.InvariantCulture; + /// + /// Gets or sets format provider to be used by ToStringWithCulture method. + /// + public System.IFormatProvider FormatProvider + { + get + { + return this.formatProviderField ; + } + set + { + if ((value != null)) + { + this.formatProviderField = value; + } + } + } + /// + /// This is called from the compile/run appdomain to convert objects within an expression block to a string + /// + public string ToStringWithCulture(object objectToConvert) + { + if ((objectToConvert == null)) + { + throw new global::System.ArgumentNullException("objectToConvert"); + } + System.Type t = objectToConvert.GetType(); + System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] { + typeof(System.IFormatProvider)}); + if ((method == null)) + { + return objectToConvert.ToString(); + } + else + { + return ((string)(method.Invoke(objectToConvert, new object[] { + this.formatProviderField }))); + } + } + } + private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper(); + /// + /// Helper to produce culture-oriented representation of an object as a string + /// + public ToStringInstanceHelper ToStringHelper + { + get + { + return this.toStringHelperField; + } + } + #endregion + } + #endregion +} diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBInvoke.tt b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBInvoke.tt new file mode 100644 index 000000000..e4a8a32fb --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBInvoke.tt @@ -0,0 +1,98 @@ +<#@ template language="C#" #> +<#@ assembly name="System.Core" #> +<#@ import namespace="System.Linq" #> +<#@ import namespace="System.Text" #> +<#@ import namespace="System.Collections.Generic" #> +<#@ import namespace="Amazon.Lambda.Annotations.SourceGenerator.Extensions" #> +<#@ import namespace="Amazon.Lambda.Annotations.SourceGenerator.Models" #> +<#@ import namespace="Amazon.Lambda.Annotations.SourceGenerator.Models.Attributes" #> +<# + if (_model.GeneratedMethod.ReturnType.FullName == _model.LambdaMethod.ReturnType.FullName) + { + // User already returns ApplicationLoadBalancerResponse (or Task), + // just pass through. + if (_model.LambdaMethod.ReturnsVoid) + { +#> + <#= _model.LambdaMethod.ContainingType.Name.ToCamelCase() #>.<#= _model.LambdaMethod.Name #>(<#= _parameterSignature #>); +<# + } + else if (_model.LambdaMethod.ReturnsVoidTask) + { +#> + await <#= _model.LambdaMethod.ContainingType.Name.ToCamelCase() #>.<#= _model.LambdaMethod.Name #>(<#= _parameterSignature #>); +<# + } + else + { +#> + var response = <#= _model.LambdaMethod.ReturnsGenericTask ? "await " : "" #><#= _model.LambdaMethod.ContainingType.Name.ToCamelCase() #>.<#= _model.LambdaMethod.Name #>(<#= _parameterSignature #>); + return response; +<# + } + } + else + { + // User returns a non-ALB type, we need to wrap in ApplicationLoadBalancerResponse + if (_model.LambdaMethod.ReturnsVoid) + { +#> + <#= _model.LambdaMethod.ContainingType.Name.ToCamelCase() #>.<#= _model.LambdaMethod.Name #>(<#= _parameterSignature #>); +<# + } + else if (_model.LambdaMethod.ReturnsVoidTask) + { +#> + await <#= _model.LambdaMethod.ContainingType.Name.ToCamelCase() #>.<#= _model.LambdaMethod.Name #>(<#= _parameterSignature #>); +<# + } + else + { +#> + var response = <#= _model.LambdaMethod.ReturnsGenericTask ? "await " : "" #><#= _model.LambdaMethod.ContainingType.Name.ToCamelCase() #>.<#= _model.LambdaMethod.Name #>(<#= _parameterSignature #>); +<# + if (_model.LambdaMethod.ReturnType.IsValueType) + { +#> + + var body = response.ToString(); +<# + } + else if (_model.LambdaMethod.ReturnType.IsString()) + { + // no action needed, response is already a string + } + else + { +#> + var memoryStream = new MemoryStream(); + serializer.Serialize(response, memoryStream); + memoryStream.Position = 0; + + // convert stream to string + StreamReader reader = new StreamReader( memoryStream ); + var body = reader.ReadToEnd(); +<# + } + } +#> + + return new Amazon.Lambda.ApplicationLoadBalancerEvents.ApplicationLoadBalancerResponse + { +<# + if (!_model.LambdaMethod.ReturnsVoid && !_model.LambdaMethod.ReturnsVoidTask) + { +#> + Body = <#= _model.LambdaMethod.ReturnType.IsString() ? "response" : "body" #>, + Headers = new Dictionary + { + {"Content-Type", <#= _model.LambdaMethod.ReturnType.IsString() ? "\"text/plain\"" : "\"application/json\"" #>} + }, +<# + } +#> + StatusCode = 200 + }; +<# + } +#> diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBInvokeCode.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBInvokeCode.cs new file mode 100644 index 000000000..04076566c --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBInvokeCode.cs @@ -0,0 +1,20 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Annotations.SourceGenerator.Models; + +namespace Amazon.Lambda.Annotations.SourceGenerator.Templates +{ + public partial class ALBInvoke + { + private readonly LambdaFunctionModel _model; + + public readonly string _parameterSignature; + + public ALBInvoke(LambdaFunctionModel model, string parameterSignature) + { + _model = model; + _parameterSignature = parameterSignature; + } + } +} diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBSetupParameters.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBSetupParameters.cs new file mode 100644 index 000000000..a6ce865bc --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBSetupParameters.cs @@ -0,0 +1,604 @@ +// ------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +// ------------------------------------------------------------------------------ +namespace Amazon.Lambda.Annotations.SourceGenerator.Templates +{ + using System.Linq; + using System.Text; + using System.Collections.Generic; + using Amazon.Lambda.Annotations.SourceGenerator.Extensions; + using Amazon.Lambda.Annotations.SourceGenerator.Models; + using Amazon.Lambda.Annotations.SourceGenerator.Models.Attributes; + using System; + + /// + /// Class to produce the template output + /// + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "18.0.0.0")] + public partial class ALBSetupParameters : ALBSetupParametersBase + { + /// + /// Create the template output + /// + public virtual string TransformText() + { + + ParameterSignature = string.Join(", ", _model.LambdaMethod.Parameters + .Select(p => + { + // Pass the same context parameter for ILambdaContext that comes from the generated method. + if (p.Type.FullName == TypeFullNames.ILambdaContext) + { + return "__context__"; + } + + // Pass the same request parameter for ALB Request Type that comes from the generated method. + if (TypeFullNames.ALBRequests.Contains(p.Type.FullName)) + { + return "__request__"; + } + + return p.Name; + })); + + var albApiAttribute = _model.LambdaMethod.Attributes.FirstOrDefault(att => att.Type.FullName == TypeFullNames.ALBApiAttribute) as AttributeModel; + + // Determine whether multi-value headers are enabled + var useMultiValue = albApiAttribute?.Data?.IsMultiValueHeadersSet == true && albApiAttribute.Data.MultiValueHeaders; + + if (_model.LambdaMethod.Parameters.HasConvertibleParameter()) + { + + this.Write(" var validationErrors = new List();\r\n\r\n"); + + } + + foreach (var parameter in _model.LambdaMethod.Parameters) + { + if (parameter.Type.FullName == TypeFullNames.ILambdaContext || TypeFullNames.ALBRequests.Contains(parameter.Type.FullName)) + { + // No action required for ILambdaContext and ALB RequestType, they are passed from the generated method parameter directly to the original method. + } + else if (parameter.Attributes.Any(att => att.Type.FullName == TypeFullNames.FromServiceAttribute)) + { + + this.Write(" var "); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name)); + this.Write(" = scope.ServiceProvider.GetRequiredService<"); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Type.FullName)); + this.Write(">();\r\n"); + + } + else if (parameter.Attributes.Any(att => att.Type.FullName == TypeFullNames.ALBFromQueryAttribute)) + { + var fromQueryAttribute = parameter.Attributes.First(att => att.Type.FullName == TypeFullNames.ALBFromQueryAttribute) as AttributeModel; + + // Use parameter name as key, if Name has not specified explicitly in the attribute definition. + var parameterKey = fromQueryAttribute?.Data?.Name ?? parameter.Name; + + var queryStringParameters = useMultiValue ? "MultiValueQueryStringParameters" : "QueryStringParameters"; + + this.Write(" var "); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name)); + this.Write(" = default("); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Type.FullName)); + this.Write(");\r\n"); + + if (parameter.Type.IsEnumerable && parameter.Type.IsGenericType) + { + if (useMultiValue) + { + // Multi-value mode: MultiValueQueryStringParameters is IDictionary> + if (parameter.Type.TypeArguments.Count != 1) + { + throw new NotSupportedException("Only one type argument is supported for generic types."); + } + + var typeArgument = parameter.Type.TypeArguments.First(); + this.Write(" if (__request__."); + this.Write(this.ToStringHelper.ToStringWithCulture(queryStringParameters)); + this.Write("?.ContainsKey(\""); + this.Write(this.ToStringHelper.ToStringWithCulture(parameterKey)); + this.Write("\") == true)\r\n {\r\n "); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name)); + this.Write(" = __request__."); + this.Write(this.ToStringHelper.ToStringWithCulture(queryStringParameters)); + this.Write("[\""); + this.Write(this.ToStringHelper.ToStringWithCulture(parameterKey)); + this.Write("\"]\r\n .Select(q =>\r\n {\r\n try\r\n {\r\n return ("); + this.Write(this.ToStringHelper.ToStringWithCulture(typeArgument.FullName)); + this.Write(")Convert.ChangeType(q, typeof("); + this.Write(this.ToStringHelper.ToStringWithCulture(typeArgument.FullNameWithoutAnnotations)); + this.Write("));\r\n }\r\n catch (Exception e) when (e is InvalidCastException || e is FormatException || e is OverflowException || e is ArgumentException)\r\n {\r\n validationErrors.Add($\"Value {q} at \'"); + this.Write(this.ToStringHelper.ToStringWithCulture(parameterKey)); + this.Write("\' failed to satisfy constraint: {e.Message}\");\r\n return default;\r\n }\r\n })\r\n .ToList();\r\n }\r\n\r\n"); + + } + else + { + // Single-value mode: QueryStringParameters is IDictionary + // Split by comma to support multiple values + if (parameter.Type.TypeArguments.Count != 1) + { + throw new NotSupportedException("Only one type argument is supported for generic types."); + } + + var typeArgument = parameter.Type.TypeArguments.First(); + this.Write(" if (__request__."); + this.Write(this.ToStringHelper.ToStringWithCulture(queryStringParameters)); + this.Write("?.ContainsKey(\""); + this.Write(this.ToStringHelper.ToStringWithCulture(parameterKey)); + this.Write("\") == true)\r\n {\r\n "); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name)); + this.Write(" = __request__."); + this.Write(this.ToStringHelper.ToStringWithCulture(queryStringParameters)); + this.Write("[\""); + this.Write(this.ToStringHelper.ToStringWithCulture(parameterKey)); + this.Write("\"].Split(\",\")\r\n .Select(q =>\r\n {\r\n try\r\n {\r\n return ("); + this.Write(this.ToStringHelper.ToStringWithCulture(typeArgument.FullName)); + this.Write(")Convert.ChangeType(q, typeof("); + this.Write(this.ToStringHelper.ToStringWithCulture(typeArgument.FullNameWithoutAnnotations)); + this.Write("));\r\n }\r\n catch (Exception e) when (e is InvalidCastException || e is FormatException || e is OverflowException || e is ArgumentException)\r\n {\r\n validationErrors.Add($\"Value {q} at \'"); + this.Write(this.ToStringHelper.ToStringWithCulture(parameterKey)); + this.Write("\' failed to satisfy constraint: {e.Message}\");\r\n return default;\r\n }\r\n })\r\n .ToList();\r\n }\r\n\r\n"); + + } + } + else + { + // Non-generic types are mapped directly to the target parameter. + this.Write(" if (__request__."); + this.Write(this.ToStringHelper.ToStringWithCulture(queryStringParameters)); + this.Write("?.ContainsKey(\""); + this.Write(this.ToStringHelper.ToStringWithCulture(parameterKey)); + this.Write("\") == true)\r\n {\r\n try\r\n {\r\n "); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name)); + this.Write(" = ("); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Type.FullName)); + this.Write(")Convert.ChangeType(__request__."); + this.Write(this.ToStringHelper.ToStringWithCulture(queryStringParameters)); + this.Write("[\""); + this.Write(this.ToStringHelper.ToStringWithCulture(parameterKey)); + this.Write("\"], typeof("); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Type.FullNameWithoutAnnotations)); + this.Write("));\r\n }\r\n catch (Exception e) when (e is InvalidCastException || e is FormatException || e is OverflowException || e is ArgumentException)\r\n {\r\n validationErrors.Add($\"Value {__request__."); + this.Write(this.ToStringHelper.ToStringWithCulture(queryStringParameters)); + this.Write("[\""); + this.Write(this.ToStringHelper.ToStringWithCulture(parameterKey)); + this.Write("\"]} at \'"); + this.Write(this.ToStringHelper.ToStringWithCulture(parameterKey)); + this.Write("\' failed to satisfy constraint: {e.Message}\");\r\n }\r\n }\r\n\r\n"); + + } + + } + else if (parameter.Attributes.Any(att => att.Type.FullName == TypeFullNames.ALBFromHeaderAttribute)) + { + var fromHeaderAttribute = + parameter.Attributes.First(att => att.Type.FullName == TypeFullNames.ALBFromHeaderAttribute) as + AttributeModel; + + // Use parameter name as key, if Name has not specified explicitly in the attribute definition. + var headerKey = fromHeaderAttribute?.Data?.Name ?? parameter.Name; + + var headers = useMultiValue ? "MultiValueHeaders" : "Headers"; + + this.Write(" var "); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name)); + this.Write(" = default("); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Type.FullName)); + this.Write(");\r\n"); + + if (parameter.Type.IsEnumerable && parameter.Type.IsGenericType) + { + if (useMultiValue) + { + // Multi-value mode: MultiValueHeaders is IDictionary> + if (parameter.Type.TypeArguments.Count != 1) + { + throw new NotSupportedException("Only one type argument is supported for generic types."); + } + + var typeArgument = parameter.Type.TypeArguments.First(); + this.Write(" if (__request__."); + this.Write(this.ToStringHelper.ToStringWithCulture(headers)); + this.Write("?.Any(x => string.Equals(x.Key, \""); + this.Write(this.ToStringHelper.ToStringWithCulture(headerKey)); + this.Write("\", StringComparison.OrdinalIgnoreCase)) == true)\r\n {\r\n "); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name)); + this.Write(" = __request__."); + this.Write(this.ToStringHelper.ToStringWithCulture(headers)); + this.Write(".First(x => string.Equals(x.Key, \""); + this.Write(this.ToStringHelper.ToStringWithCulture(headerKey)); + this.Write("\", StringComparison.OrdinalIgnoreCase)).Value\r\n .Select(q =>\r\n {\r\n try\r\n {\r\n return ("); + this.Write(this.ToStringHelper.ToStringWithCulture(typeArgument.FullName)); + this.Write(")Convert.ChangeType(q, typeof("); + this.Write(this.ToStringHelper.ToStringWithCulture(typeArgument.FullNameWithoutAnnotations)); + this.Write("));\r\n }\r\n catch (Exception e) when (e is InvalidCastException || e is FormatException || e is OverflowException || e is ArgumentException)\r\n {\r\n validationErrors.Add($\"Value {q} at \'"); + this.Write(this.ToStringHelper.ToStringWithCulture(headerKey)); + this.Write("\' failed to satisfy constraint: {e.Message}\");\r\n return default;\r\n }\r\n })\r\n .ToList();\r\n }\r\n\r\n"); + + } + else + { + // Single-value mode: Headers is IDictionary + // Split by comma to support multiple values + if (parameter.Type.TypeArguments.Count != 1) + { + throw new NotSupportedException("Only one type argument is supported for generic types."); + } + + var typeArgument = parameter.Type.TypeArguments.First(); + this.Write(" if (__request__."); + this.Write(this.ToStringHelper.ToStringWithCulture(headers)); + this.Write("?.Any(x => string.Equals(x.Key, \""); + this.Write(this.ToStringHelper.ToStringWithCulture(headerKey)); + this.Write("\", StringComparison.OrdinalIgnoreCase)) == true)\r\n {\r\n "); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name)); + this.Write(" = __request__."); + this.Write(this.ToStringHelper.ToStringWithCulture(headers)); + this.Write(".First(x => string.Equals(x.Key, \""); + this.Write(this.ToStringHelper.ToStringWithCulture(headerKey)); + this.Write("\", StringComparison.OrdinalIgnoreCase)).Value.Split(\",\")\r\n .Select(q =>\r\n {\r\n try\r\n {\r\n return ("); + this.Write(this.ToStringHelper.ToStringWithCulture(typeArgument.FullName)); + this.Write(")Convert.ChangeType(q, typeof("); + this.Write(this.ToStringHelper.ToStringWithCulture(typeArgument.FullNameWithoutAnnotations)); + this.Write("));\r\n }\r\n catch (Exception e) when (e is InvalidCastException || e is FormatException || e is OverflowException || e is ArgumentException)\r\n {\r\n validationErrors.Add($\"Value {q} at \'"); + this.Write(this.ToStringHelper.ToStringWithCulture(headerKey)); + this.Write("\' failed to satisfy constraint: {e.Message}\");\r\n return default;\r\n }\r\n })\r\n .ToList();\r\n }\r\n\r\n"); + + } + } + else + { + // Non-generic types are mapped directly to the target parameter. + this.Write(" if (__request__."); + this.Write(this.ToStringHelper.ToStringWithCulture(headers)); + this.Write("?.Any(x => string.Equals(x.Key, \""); + this.Write(this.ToStringHelper.ToStringWithCulture(headerKey)); + this.Write("\", StringComparison.OrdinalIgnoreCase)) == true)\r\n {\r\n try\r\n {\r\n "); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name)); + this.Write(" = ("); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Type.FullName)); + this.Write(")Convert.ChangeType(__request__."); + this.Write(this.ToStringHelper.ToStringWithCulture(headers)); + this.Write(".First(x => string.Equals(x.Key, \""); + this.Write(this.ToStringHelper.ToStringWithCulture(headerKey)); + this.Write("\", StringComparison.OrdinalIgnoreCase)).Value, typeof("); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Type.FullNameWithoutAnnotations)); + this.Write("));\r\n }\r\n catch (Exception e) when (e is InvalidCastException || e is FormatException || e is OverflowException || e is ArgumentException)\r\n {\r\n validationErrors.Add($\"Value {__request__."); + this.Write(this.ToStringHelper.ToStringWithCulture(headers)); + this.Write(".First(x => string.Equals(x.Key, \""); + this.Write(this.ToStringHelper.ToStringWithCulture(headerKey)); + this.Write("\", StringComparison.OrdinalIgnoreCase)).Value} at \'"); + this.Write(this.ToStringHelper.ToStringWithCulture(headerKey)); + this.Write("\' failed to satisfy constraint: {e.Message}\");\r\n }\r\n }\r\n\r\n"); + + } + } + else if (parameter.Attributes.Any(att => att.Type.FullName == TypeFullNames.ALBFromBodyAttribute)) + { + // string parameter does not need to be de-serialized + if (parameter.Type.IsString()) + { + + this.Write(" var "); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name)); + this.Write(" = __request__.Body;\r\n\r\n"); + + } + else + { + + this.Write(" var "); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name)); + this.Write(" = default("); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Type.FullName)); + this.Write(");\r\n try\r\n {\r\n // convert string to stream\r\n var byteArray = Encoding.UTF8.GetBytes(__request__.Body);\r\n var stream = new MemoryStream(byteArray);\r\n "); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name)); + this.Write(" = serializer.Deserialize<"); + this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Type.FullName)); + this.Write(">(stream);\r\n }\r\n catch (Exception e)\r\n {\r\n validationErrors.Add($\"Value {__request__.Body} at \'body\' failed to satisfy constraint: {e.Message}\");\r\n }\r\n\r\n"); + + } + } + else + { + throw new NotSupportedException($"{parameter.Name} parameter of type {parameter.Type.FullName} passing is not supported for ALB functions. Use [FromHeader], [FromQuery], [FromBody], or [FromServices] attributes."); + } + } + + if (_model.LambdaMethod.Parameters.HasConvertibleParameter()) + { + + this.Write(" // return 400 Bad Request if there exists a validation error\r\n" + + " if (validationErrors.Any())\r\n" + + " {\r\n" + + " var errorResult = new Amazon.Lambda.ApplicationLoadBalancerEvents.ApplicationLoadBalancerResponse\r\n" + + " {\r\n" + + " Body = @$\"{{\"\"message\"\": \"\"{validationErrors.Count} validation error(s) detected: {string.Join(\",\", validationErrors)}\"\"}}\",\r\n" + + " Headers = new Dictionary\r\n" + + " {\r\n" + + " {\"Content-Type\", \"application/json\"}\r\n" + + " },\r\n" + + " StatusCode = 400\r\n" + + " };\r\n" + + " return errorResult;\r\n" + + " }\r\n\r\n"); + + } + + return this.GenerationEnvironment.ToString(); + } + } + + #region Base class + /// + /// Base class for this transformation + /// + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "18.0.0.0")] + public class ALBSetupParametersBase + { + #region Fields + private global::System.Text.StringBuilder generationEnvironmentField; + private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField; + private global::System.Collections.Generic.List indentLengthsField; + private string currentIndentField = ""; + private bool endsWithNewline; + private global::System.Collections.Generic.IDictionary sessionField; + #endregion + #region Properties + /// + /// The string builder that generation-time code is using to assemble generated output + /// + public System.Text.StringBuilder GenerationEnvironment + { + get + { + if ((this.generationEnvironmentField == null)) + { + this.generationEnvironmentField = new global::System.Text.StringBuilder(); + } + return this.generationEnvironmentField; + } + set + { + this.generationEnvironmentField = value; + } + } + /// + /// The error collection for the generation process + /// + public System.CodeDom.Compiler.CompilerErrorCollection Errors + { + get + { + if ((this.errorsField == null)) + { + this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection(); + } + return this.errorsField; + } + } + /// + /// A list of the lengths of each indent that was added with PushIndent + /// + private System.Collections.Generic.List indentLengths + { + get + { + if ((this.indentLengthsField == null)) + { + this.indentLengthsField = new global::System.Collections.Generic.List(); + } + return this.indentLengthsField; + } + } + /// + /// Gets the current indent we use when adding lines to the output + /// + public string CurrentIndent + { + get + { + return this.currentIndentField; + } + } + /// + /// Current transformation session + /// + public virtual global::System.Collections.Generic.IDictionary Session + { + get + { + return this.sessionField; + } + set + { + this.sessionField = value; + } + } + #endregion + #region Transform-time helpers + /// + /// Write text directly into the generated output + /// + public void Write(string textToAppend) + { + if (string.IsNullOrEmpty(textToAppend)) + { + return; + } + if (((this.GenerationEnvironment.Length == 0) + || this.endsWithNewline)) + { + this.GenerationEnvironment.Append(this.currentIndentField); + this.endsWithNewline = false; + } + if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture)) + { + this.endsWithNewline = true; + } + if ((this.currentIndentField.Length == 0)) + { + this.GenerationEnvironment.Append(textToAppend); + return; + } + textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField)); + if (this.endsWithNewline) + { + this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length)); + } + else + { + this.GenerationEnvironment.Append(textToAppend); + } + } + /// + /// Write text directly into the generated output + /// + public void WriteLine(string textToAppend) + { + this.Write(textToAppend); + this.GenerationEnvironment.AppendLine(); + this.endsWithNewline = true; + } + /// + /// Write formatted text directly into the generated output + /// + public void Write(string format, params object[] args) + { + this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Write formatted text directly into the generated output + /// + public void WriteLine(string format, params object[] args) + { + this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Raise an error + /// + public void Error(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + this.Errors.Add(error); + } + /// + /// Raise a warning + /// + public void Warning(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + error.IsWarning = true; + this.Errors.Add(error); + } + /// + /// Increase the indent + /// + public void PushIndent(string indent) + { + if ((indent == null)) + { + throw new global::System.ArgumentNullException("indent"); + } + this.currentIndentField = (this.currentIndentField + indent); + this.indentLengths.Add(indent.Length); + } + /// + /// Remove the last indent that was added with PushIndent + /// + public string PopIndent() + { + string returnValue = ""; + if ((this.indentLengths.Count > 0)) + { + int indentLength = this.indentLengths[(this.indentLengths.Count - 1)]; + this.indentLengths.RemoveAt((this.indentLengths.Count - 1)); + if ((indentLength > 0)) + { + returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength)); + this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength)); + } + } + return returnValue; + } + /// + /// Remove any indentation + /// + public void ClearIndent() + { + this.indentLengths.Clear(); + this.currentIndentField = ""; + } + #endregion + #region ToString Helpers + /// + /// Utility class to produce culture-oriented representation of an object as a string. + /// + public class ToStringInstanceHelper + { + private System.IFormatProvider formatProviderField = global::System.Globalization.CultureInfo.InvariantCulture; + /// + /// Gets or sets format provider to be used by ToStringWithCulture method. + /// + public System.IFormatProvider FormatProvider + { + get + { + return this.formatProviderField ; + } + set + { + if ((value != null)) + { + this.formatProviderField = value; + } + } + } + /// + /// This is called from the compile/run appdomain to convert objects within an expression block to a string + /// + public string ToStringWithCulture(object objectToConvert) + { + if ((objectToConvert == null)) + { + throw new global::System.ArgumentNullException("objectToConvert"); + } + System.Type t = objectToConvert.GetType(); + System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] { + typeof(System.IFormatProvider)}); + if ((method == null)) + { + return objectToConvert.ToString(); + } + else + { + return ((string)(method.Invoke(objectToConvert, new object[] { + this.formatProviderField }))); + } + } + } + private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper(); + /// + /// Helper to produce culture-oriented representation of an object as a string + /// + public ToStringInstanceHelper ToStringHelper + { + get + { + return this.toStringHelperField; + } + } + #endregion + } + #endregion +} diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBSetupParameters.tt b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBSetupParameters.tt new file mode 100644 index 000000000..c90cff8b6 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBSetupParameters.tt @@ -0,0 +1,303 @@ +<#@ template language="C#" #> +<#@ assembly name="System.Core" #> +<#@ import namespace="System.Linq" #> +<#@ import namespace="System.Text" #> +<#@ import namespace="System.Collections.Generic" #> +<#@ import namespace="Amazon.Lambda.Annotations.SourceGenerator.Extensions" #> +<#@ import namespace="Amazon.Lambda.Annotations.SourceGenerator.Models" #> +<#@ import namespace="Amazon.Lambda.Annotations.SourceGenerator.Models.Attributes" #> +<# + ParameterSignature = string.Join(", ", _model.LambdaMethod.Parameters + .Select(p => + { + // Pass the same context parameter for ILambdaContext that comes from the generated method. + if (p.Type.FullName == TypeFullNames.ILambdaContext) + { + return "__context__"; + } + + // Pass the same request parameter for ALB Request Type that comes from the generated method. + if (TypeFullNames.ALBRequests.Contains(p.Type.FullName)) + { + return "__request__"; + } + + return p.Name; + })); + + var albApiAttribute = _model.LambdaMethod.Attributes.FirstOrDefault(att => att.Type.FullName == TypeFullNames.ALBApiAttribute) as AttributeModel; + + // Determine whether multi-value headers are enabled + var useMultiValue = albApiAttribute?.Data?.IsMultiValueHeadersSet == true && albApiAttribute.Data.MultiValueHeaders; + + if (_model.LambdaMethod.Parameters.HasConvertibleParameter()) + { +#> + var validationErrors = new List(); + +<# + } + + foreach (var parameter in _model.LambdaMethod.Parameters) + { + if (parameter.Type.FullName == TypeFullNames.ILambdaContext || TypeFullNames.ALBRequests.Contains(parameter.Type.FullName)) + { + // No action required for ILambdaContext and ALB RequestType, they are passed from the generated method parameter directly to the original method. + } + else if (parameter.Attributes.Any(att => att.Type.FullName == TypeFullNames.FromServiceAttribute)) + { +#> + var <#= parameter.Name #> = scope.ServiceProvider.GetRequiredService<<#= parameter.Type.FullName #>>(); +<# + } + else if (parameter.Attributes.Any(att => att.Type.FullName == TypeFullNames.ALBFromQueryAttribute)) + { + var fromQueryAttribute = parameter.Attributes.First(att => att.Type.FullName == TypeFullNames.ALBFromQueryAttribute) as AttributeModel; + + // Use parameter name as key, if Name has not specified explicitly in the attribute definition. + var parameterKey = fromQueryAttribute?.Data?.Name ?? parameter.Name; + + var queryStringParameters = useMultiValue ? "MultiValueQueryStringParameters" : "QueryStringParameters"; + +#> + var <#= parameter.Name #> = default(<#= parameter.Type.FullName #>); +<# + + if (parameter.Type.IsEnumerable && parameter.Type.IsGenericType) + { + if (useMultiValue) + { + // Multi-value mode: MultiValueQueryStringParameters is IDictionary> + if (parameter.Type.TypeArguments.Count != 1) + { + throw new NotSupportedException("Only one type argument is supported for generic types."); + } + + var typeArgument = parameter.Type.TypeArguments.First(); +#> + if (__request__.<#= queryStringParameters #>?.ContainsKey("<#= parameterKey #>") == true) + { + <#= parameter.Name #> = __request__.<#= queryStringParameters #>["<#= parameterKey #>"] + .Select(q => + { + try + { + return (<#= typeArgument.FullName #>)Convert.ChangeType(q, typeof(<#= typeArgument.FullNameWithoutAnnotations #>)); + } + catch (Exception e) when (e is InvalidCastException || e is FormatException || e is OverflowException || e is ArgumentException) + { + validationErrors.Add($"Value {q} at '<#= parameterKey #>' failed to satisfy constraint: {e.Message}"); + return default; + } + }) + .ToList(); + } + +<# + } + else + { + // Single-value mode: QueryStringParameters is IDictionary + // Split by comma to support multiple values + if (parameter.Type.TypeArguments.Count != 1) + { + throw new NotSupportedException("Only one type argument is supported for generic types."); + } + + var typeArgument = parameter.Type.TypeArguments.First(); +#> + if (__request__.<#= queryStringParameters #>?.ContainsKey("<#= parameterKey #>") == true) + { + <#= parameter.Name #> = __request__.<#= queryStringParameters #>["<#= parameterKey #>"].Split(",") + .Select(q => + { + try + { + return (<#= typeArgument.FullName #>)Convert.ChangeType(q, typeof(<#= typeArgument.FullNameWithoutAnnotations #>)); + } + catch (Exception e) when (e is InvalidCastException || e is FormatException || e is OverflowException || e is ArgumentException) + { + validationErrors.Add($"Value {q} at '<#= parameterKey #>' failed to satisfy constraint: {e.Message}"); + return default; + } + }) + .ToList(); + } + +<# + } + } + else + { + // Non-generic types are mapped directly to the target parameter. +#> + if (__request__.<#= queryStringParameters #>?.ContainsKey("<#= parameterKey #>") == true) + { + try + { + <#= parameter.Name #> = (<#= parameter.Type.FullName #>)Convert.ChangeType(__request__.<#= queryStringParameters #>["<#= parameterKey #>"], typeof(<#= parameter.Type.FullNameWithoutAnnotations #>)); + } + catch (Exception e) when (e is InvalidCastException || e is FormatException || e is OverflowException || e is ArgumentException) + { + validationErrors.Add($"Value {__request__.<#= queryStringParameters #>["<#= parameterKey #>"]} at '<#= parameterKey #>' failed to satisfy constraint: {e.Message}"); + } + } + +<# + } + + } + else if (parameter.Attributes.Any(att => att.Type.FullName == TypeFullNames.ALBFromHeaderAttribute)) + { + var fromHeaderAttribute = + parameter.Attributes.First(att => att.Type.FullName == TypeFullNames.ALBFromHeaderAttribute) as + AttributeModel; + + // Use parameter name as key, if Name has not specified explicitly in the attribute definition. + var headerKey = fromHeaderAttribute?.Data?.Name ?? parameter.Name; + + var headers = useMultiValue ? "MultiValueHeaders" : "Headers"; + +#> + var <#= parameter.Name #> = default(<#= parameter.Type.FullName #>); +<# + + if (parameter.Type.IsEnumerable && parameter.Type.IsGenericType) + { + if (useMultiValue) + { + // Multi-value mode: MultiValueHeaders is IDictionary> + if (parameter.Type.TypeArguments.Count != 1) + { + throw new NotSupportedException("Only one type argument is supported for generic types."); + } + + var typeArgument = parameter.Type.TypeArguments.First(); +#> + if (__request__.<#= headers #>?.Any(x => string.Equals(x.Key, "<#= headerKey #>", StringComparison.OrdinalIgnoreCase)) == true) + { + <#= parameter.Name #> = __request__.<#= headers #>.First(x => string.Equals(x.Key, "<#= headerKey #>", StringComparison.OrdinalIgnoreCase)).Value + .Select(q => + { + try + { + return (<#= typeArgument.FullName #>)Convert.ChangeType(q, typeof(<#= typeArgument.FullNameWithoutAnnotations #>)); + } + catch (Exception e) when (e is InvalidCastException || e is FormatException || e is OverflowException || e is ArgumentException) + { + validationErrors.Add($"Value {q} at '<#= headerKey #>' failed to satisfy constraint: {e.Message}"); + return default; + } + }) + .ToList(); + } + +<# + } + else + { + // Single-value mode: Headers is IDictionary + // Split by comma to support multiple values + if (parameter.Type.TypeArguments.Count != 1) + { + throw new NotSupportedException("Only one type argument is supported for generic types."); + } + + var typeArgument = parameter.Type.TypeArguments.First(); +#> + if (__request__.<#= headers #>?.Any(x => string.Equals(x.Key, "<#= headerKey #>", StringComparison.OrdinalIgnoreCase)) == true) + { + <#= parameter.Name #> = __request__.<#= headers #>.First(x => string.Equals(x.Key, "<#= headerKey #>", StringComparison.OrdinalIgnoreCase)).Value.Split(",") + .Select(q => + { + try + { + return (<#= typeArgument.FullName #>)Convert.ChangeType(q, typeof(<#= typeArgument.FullNameWithoutAnnotations #>)); + } + catch (Exception e) when (e is InvalidCastException || e is FormatException || e is OverflowException || e is ArgumentException) + { + validationErrors.Add($"Value {q} at '<#= headerKey #>' failed to satisfy constraint: {e.Message}"); + return default; + } + }) + .ToList(); + } + +<# + } + } + else + { + // Non-generic types are mapped directly to the target parameter. +#> + if (__request__.<#= headers #>?.Any(x => string.Equals(x.Key, "<#= headerKey #>", StringComparison.OrdinalIgnoreCase)) == true) + { + try + { + <#= parameter.Name #> = (<#= parameter.Type.FullName #>)Convert.ChangeType(__request__.<#= headers #>.First(x => string.Equals(x.Key, "<#= headerKey #>", StringComparison.OrdinalIgnoreCase)).Value, typeof(<#= parameter.Type.FullNameWithoutAnnotations #>)); + } + catch (Exception e) when (e is InvalidCastException || e is FormatException || e is OverflowException || e is ArgumentException) + { + validationErrors.Add($"Value {__request__.<#= headers #>.First(x => string.Equals(x.Key, "<#= headerKey #>", StringComparison.OrdinalIgnoreCase)).Value} at '<#= headerKey #>' failed to satisfy constraint: {e.Message}"); + } + } + +<# + } + } + else if (parameter.Attributes.Any(att => att.Type.FullName == TypeFullNames.ALBFromBodyAttribute)) + { + // string parameter does not need to be de-serialized + if (parameter.Type.IsString()) + { + #> + var <#= parameter.Name #> = __request__.Body; + +<# + } + else + { + #> + var <#= parameter.Name #> = default(<#= parameter.Type.FullName #>); + try + { + // convert string to stream + var byteArray = Encoding.UTF8.GetBytes(__request__.Body); + var stream = new MemoryStream(byteArray); + <#= parameter.Name #> = serializer.Deserialize<<#= parameter.Type.FullName #>>(stream); + } + catch (Exception e) + { + validationErrors.Add($"Value {__request__.Body} at 'body' failed to satisfy constraint: {e.Message}"); + } + +<# + } + } + else + { + throw new NotSupportedException($"{parameter.Name} parameter of type {parameter.Type.FullName} passing is not supported for ALB functions. Use [FromHeader], [FromQuery], [FromBody], or [FromServices] attributes."); + } + } + + if (_model.LambdaMethod.Parameters.HasConvertibleParameter()) + { +#> + // return 400 Bad Request if there exists a validation error + if (validationErrors.Any()) + { + var errorResult = new Amazon.Lambda.ApplicationLoadBalancerEvents.ApplicationLoadBalancerResponse + { + Body = @$"{{""message"": ""{validationErrors.Count} validation error(s) detected: {string.Join(",", validationErrors)}""}}", + Headers = new Dictionary + { + {"Content-Type", "application/json"} + }, + StatusCode = 400 + }; + return errorResult; + } + +<# + } +#> diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBSetupParametersCode.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBSetupParametersCode.cs new file mode 100644 index 000000000..678f28859 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/ALBSetupParametersCode.cs @@ -0,0 +1,19 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Annotations.SourceGenerator.Models; + +namespace Amazon.Lambda.Annotations.SourceGenerator.Templates +{ + public partial class ALBSetupParameters + { + private readonly LambdaFunctionModel _model; + + public string ParameterSignature { get; set; } + + public ALBSetupParameters(LambdaFunctionModel model) + { + _model = model; + } + } +} diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/LambdaFunctionTemplate.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/LambdaFunctionTemplate.cs index e2c2f957f..6e4a30347 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/LambdaFunctionTemplate.cs +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/LambdaFunctionTemplate.cs @@ -188,6 +188,12 @@ public virtual string TransformText() this.Write(apiParameters.TransformText()); this.Write(new APIGatewayInvoke(_model, apiParameters.ParameterSignature).TransformText()); } + else if (_model.LambdaMethod.Events.Contains(EventType.ALB)) + { + var albParameters = new ALBSetupParameters(_model); + this.Write(albParameters.TransformText()); + this.Write(new ALBInvoke(_model, albParameters.ParameterSignature).TransformText()); + } else { this.Write(new NoEventMethodBody(_model).TransformText()); diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/LambdaFunctionTemplate.tt b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/LambdaFunctionTemplate.tt index bacf7daf0..aa3c3ab18 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/LambdaFunctionTemplate.tt +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Templates/LambdaFunctionTemplate.tt @@ -66,6 +66,12 @@ this.Write(new FieldsAndConstructor(_model).TransformText()); this.Write(apiParameters.TransformText()); this.Write(new APIGatewayInvoke(_model, apiParameters.ParameterSignature).TransformText()); } + else if (_model.LambdaMethod.Events.Contains(EventType.ALB)) + { + var albParameters = new ALBSetupParameters(_model); + this.Write(albParameters.TransformText()); + this.Write(new ALBInvoke(_model, albParameters.ParameterSignature).TransformText()); + } else { this.Write(new NoEventMethodBody(_model).TransformText()); diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/TypeFullNames.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/TypeFullNames.cs index 6e15c2175..4c66c1875 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/TypeFullNames.cs +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/TypeFullNames.cs @@ -1,3 +1,6 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + using System.Collections.Generic; namespace Amazon.Lambda.Annotations.SourceGenerator @@ -34,6 +37,9 @@ public static class TypeFullNames public const string FromRouteAttribute = "Amazon.Lambda.Annotations.APIGateway.FromRouteAttribute"; public const string FromCustomAuthorizerAttribute = "Amazon.Lambda.Annotations.APIGateway.FromCustomAuthorizerAttribute"; + public const string FunctionUrlAttribute = "Amazon.Lambda.Annotations.APIGateway.FunctionUrlAttribute"; + public const string FunctionUrlAuthType = "Amazon.Lambda.Annotations.APIGateway.FunctionUrlAuthType"; + public const string HttpApiAuthorizerAttribute = "Amazon.Lambda.Annotations.APIGateway.HttpApiAuthorizerAttribute"; public const string RestApiAuthorizerAttribute = "Amazon.Lambda.Annotations.APIGateway.RestApiAuthorizerAttribute"; @@ -46,6 +52,16 @@ public static class TypeFullNames public const string SQSBatchResponse = "Amazon.Lambda.SQSEvents.SQSBatchResponse"; public const string SQSEventAttribute = "Amazon.Lambda.Annotations.SQS.SQSEventAttribute"; + public const string ApplicationLoadBalancerRequest = "Amazon.Lambda.ApplicationLoadBalancerEvents.ApplicationLoadBalancerRequest"; + public const string ApplicationLoadBalancerResponse = "Amazon.Lambda.ApplicationLoadBalancerEvents.ApplicationLoadBalancerResponse"; + public const string ALBApiAttribute = "Amazon.Lambda.Annotations.ALB.ALBApiAttribute"; + public const string ALBFromQueryAttribute = "Amazon.Lambda.Annotations.ALB.FromQueryAttribute"; + public const string ALBFromHeaderAttribute = "Amazon.Lambda.Annotations.ALB.FromHeaderAttribute"; + public const string ALBFromBodyAttribute = "Amazon.Lambda.Annotations.ALB.FromBodyAttribute"; + + public const string S3Event = "Amazon.Lambda.S3Events.S3Event"; + public const string S3EventAttribute = "Amazon.Lambda.Annotations.S3.S3EventAttribute"; + public const string LambdaSerializerAttribute = "Amazon.Lambda.Core.LambdaSerializerAttribute"; public const string DefaultLambdaSerializer = "Amazon.Lambda.Serialization.SystemTextJson.DefaultLambdaJsonSerializer"; @@ -63,11 +79,19 @@ public static class TypeFullNames APIGatewayCustomAuthorizerRequest }; + public static HashSet ALBRequests = new HashSet + { + ApplicationLoadBalancerRequest + }; + public static HashSet Events = new HashSet { RestApiAttribute, HttpApiAttribute, - SQSEventAttribute + FunctionUrlAttribute, + SQSEventAttribute, + ALBApiAttribute, + S3EventAttribute }; } -} \ No newline at end of file +} diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Validation/LambdaFunctionValidator.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Validation/LambdaFunctionValidator.cs index 733124209..4ea09acdf 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Validation/LambdaFunctionValidator.cs +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Validation/LambdaFunctionValidator.cs @@ -1,4 +1,9 @@ -using Amazon.Lambda.Annotations.SourceGenerator.Diagnostics; +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Annotations.ALB; +using Amazon.Lambda.Annotations.S3; +using Amazon.Lambda.Annotations.SourceGenerator.Diagnostics; using Amazon.Lambda.Annotations.SourceGenerator.Extensions; using Amazon.Lambda.Annotations.SourceGenerator.Models; using Amazon.Lambda.Annotations.SourceGenerator.Models.Attributes; @@ -59,6 +64,8 @@ internal static bool ValidateFunction(GeneratorExecutionContext context, IMethod // Validate Events ValidateApiGatewayEvents(lambdaFunctionModel, methodLocation, diagnostics); ValidateSqsEvents(lambdaFunctionModel, methodLocation, diagnostics); + ValidateAlbEvents(lambdaFunctionModel, methodLocation, diagnostics); + ValidateS3Events(lambdaFunctionModel, methodLocation, diagnostics); return ReportDiagnostics(diagnosticReporter, diagnostics); } @@ -67,6 +74,7 @@ internal static bool ValidateDependencies(GeneratorExecutionContext context, IMe { // Check for references to "Amazon.Lambda.APIGatewayEvents" if the Lambda method is annotated with RestApi, HttpApi, or authorizer attributes. if (lambdaMethodSymbol.HasAttribute(context, TypeFullNames.RestApiAttribute) || lambdaMethodSymbol.HasAttribute(context, TypeFullNames.HttpApiAttribute) + || lambdaMethodSymbol.HasAttribute(context, TypeFullNames.FunctionUrlAttribute) || lambdaMethodSymbol.HasAttribute(context, TypeFullNames.HttpApiAuthorizerAttribute) || lambdaMethodSymbol.HasAttribute(context, TypeFullNames.RestApiAuthorizerAttribute)) { if (context.Compilation.ReferencedAssemblyNames.FirstOrDefault(x => x.Name == "Amazon.Lambda.APIGatewayEvents") == null) @@ -86,6 +94,26 @@ internal static bool ValidateDependencies(GeneratorExecutionContext context, IMe } } + // Check for references to "Amazon.Lambda.ApplicationLoadBalancerEvents" if the Lambda method is annotated with ALBApi attribute. + if (lambdaMethodSymbol.HasAttribute(context, TypeFullNames.ALBApiAttribute)) + { + if (context.Compilation.ReferencedAssemblyNames.FirstOrDefault(x => x.Name == "Amazon.Lambda.ApplicationLoadBalancerEvents") == null) + { + diagnosticReporter.Report(Diagnostic.Create(DiagnosticDescriptors.MissingDependencies, methodLocation, "Amazon.Lambda.ApplicationLoadBalancerEvents")); + return false; + } + } + + // Check for references to "Amazon.Lambda.S3Events" if the Lambda method is annotated with S3Event attribute. + if (lambdaMethodSymbol.HasAttribute(context, TypeFullNames.S3EventAttribute)) + { + if (context.Compilation.ReferencedAssemblyNames.FirstOrDefault(x => x.Name == "Amazon.Lambda.S3Events") == null) + { + diagnosticReporter.Report(Diagnostic.Create(DiagnosticDescriptors.MissingDependencies, methodLocation, "Amazon.Lambda.S3Events")); + return false; + } + } + return true; } @@ -106,10 +134,12 @@ private static void ValidateApiGatewayEvents(LambdaFunctionModel lambdaFunctionM diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.AuthorizerResultOnNonAuthorizerFunction, methodLocation)); } - // If the method does not contain any API or Authorizer events, then it cannot have + // If the method does not contain any API, Authorizer, or ALB events, then it cannot have // parameters that are annotated with HTTP API attributes. // Authorizer functions also support FromHeader, FromQuery, FromRoute attributes. - if (!isApiEvent && !isAuthorizerEvent) + // ALB functions also support FromHeader, FromQuery, FromBody attributes. + var isAlbEvent = lambdaFunctionModel.LambdaMethod.Events.Contains(EventType.ALB); + if (!isApiEvent && !isAuthorizerEvent && !isAlbEvent) { foreach (var parameter in lambdaFunctionModel.LambdaMethod.Parameters) { @@ -268,6 +298,132 @@ private static void ValidateSqsEvents(LambdaFunctionModel lambdaFunctionModel, L } } + private static void ValidateAlbEvents(LambdaFunctionModel lambdaFunctionModel, Location methodLocation, List diagnostics) + { + // If the method does not contain any ALB events, then simply return early + if (!lambdaFunctionModel.LambdaMethod.Events.Contains(EventType.ALB)) + { + return; + } + + // Validate ALBApiAttributes + foreach (var att in lambdaFunctionModel.Attributes) + { + if (att.Type.FullName != TypeFullNames.ALBApiAttribute) + continue; + + var albApiAttribute = ((AttributeModel)att).Data; + var validationErrors = albApiAttribute.Validate(); + validationErrors.ForEach(errorMessage => diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.InvalidAlbApiAttribute, methodLocation, errorMessage))); + } + + // Validate method parameters + var parameters = lambdaFunctionModel.LambdaMethod.Parameters; + foreach (var parameter in parameters) + { + // [FromRoute] is not supported on ALB functions + if (parameter.Attributes.Any(att => att.Type.FullName == TypeFullNames.FromRouteAttribute)) + { + diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.FromRouteNotSupportedOnAlb, methodLocation)); + } + + // Validate [FromQuery] parameter types - only primitive types allowed + if (parameter.Attributes.Any(att => att.Type.FullName == TypeFullNames.ALBFromQueryAttribute)) + { + if (!parameter.Type.IsPrimitiveType() && !parameter.Type.IsPrimitiveEnumerableType()) + { + diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.UnsupportedMethodParameterType, methodLocation, parameter.Name, parameter.Type.FullName)); + } + } + + // Validate attribute names for FromQuery and FromHeader + foreach (var att in parameter.Attributes) + { + var parameterAttributeName = string.Empty; + switch (att.Type.FullName) + { + case TypeFullNames.ALBFromQueryAttribute: + if (att is AttributeModel albFromQueryAttribute) + parameterAttributeName = albFromQueryAttribute.Data.Name; + break; + + case TypeFullNames.ALBFromHeaderAttribute: + if (att is AttributeModel albFromHeaderAttribute) + parameterAttributeName = albFromHeaderAttribute.Data.Name; + break; + + default: + break; + } + + if (!string.IsNullOrEmpty(parameterAttributeName) && !_parameterAttributeNameRegex.IsMatch(parameterAttributeName)) + { + diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.InvalidParameterAttributeName, methodLocation, parameterAttributeName, parameter.Name)); + } + } + + // Validate that every parameter has a recognized binding + // Allowed: ILambdaContext, ApplicationLoadBalancerRequest, [FromServices], [FromQuery], [FromHeader], [FromBody] + if (parameter.Type.FullName != TypeFullNames.ILambdaContext && + !TypeFullNames.ALBRequests.Contains(parameter.Type.FullName) && + !parameter.Attributes.Any(att => + att.Type.FullName == TypeFullNames.FromServiceAttribute || + att.Type.FullName == TypeFullNames.ALBFromQueryAttribute || + att.Type.FullName == TypeFullNames.ALBFromHeaderAttribute || + att.Type.FullName == TypeFullNames.ALBFromBodyAttribute || + att.Type.FullName == TypeFullNames.FromRouteAttribute)) // FromRoute already has its own error + { + diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.AlbUnmappedParameter, methodLocation, parameter.Name)); + } + } + } + + private static void ValidateS3Events(LambdaFunctionModel lambdaFunctionModel, Location methodLocation, List diagnostics) + { + if (!lambdaFunctionModel.LambdaMethod.Events.Contains(EventType.S3)) + return; + + // Validate S3EventAttributes + var seenResourceNames = new HashSet(); + foreach (var att in lambdaFunctionModel.Attributes) + { + if (att.Type.FullName != TypeFullNames.S3EventAttribute) + continue; + + var s3EventAttribute = ((AttributeModel)att).Data; + var validationErrors = s3EventAttribute.Validate(); + validationErrors.ForEach(errorMessage => diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.InvalidS3EventAttribute, methodLocation, errorMessage))); + + // Check for duplicate resource names (only when ResourceName is safe to evaluate) + var derivedResourceName = s3EventAttribute.ResourceName; + if (!string.IsNullOrEmpty(derivedResourceName) && !seenResourceNames.Add(derivedResourceName)) + { + diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.InvalidS3EventAttribute, methodLocation, + $"Duplicate S3 event resource name '{derivedResourceName}'. Each [S3Event] attribute on the same method must have a unique ResourceName.")); + } + } + + // Validate method parameters - first param must be S3Event, optional second param ILambdaContext + var parameters = lambdaFunctionModel.LambdaMethod.Parameters; + if (parameters.Count == 0 || + parameters.Count > 2 || + (parameters.Count == 1 && parameters[0].Type.FullName != TypeFullNames.S3Event) || + (parameters.Count == 2 && (parameters[0].Type.FullName != TypeFullNames.S3Event || parameters[1].Type.FullName != TypeFullNames.ILambdaContext))) + { + var errorMessage = $"When using the {nameof(S3EventAttribute)}, the Lambda method can accept at most 2 parameters. " + + $"The first parameter is required and must be of type {TypeFullNames.S3Event}. " + + $"The second parameter is optional and must be of type {TypeFullNames.ILambdaContext}."; + diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.InvalidLambdaMethodSignature, methodLocation, errorMessage)); + } + + // Validate method return type - must be void or Task + if (!lambdaFunctionModel.LambdaMethod.ReturnsVoid && !lambdaFunctionModel.LambdaMethod.ReturnsVoidTask) + { + var errorMessage = $"When using the {nameof(S3EventAttribute)}, the Lambda method can return either void or {TypeFullNames.Task}"; + diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.InvalidLambdaMethodSignature, methodLocation, errorMessage)); + } + } + private static bool ReportDiagnostics(DiagnosticReporter diagnosticReporter, List diagnostics) { var isValid = true; diff --git a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Writers/CloudFormationWriter.cs b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Writers/CloudFormationWriter.cs index a59aaf6d4..adfa53ae5 100644 --- a/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Writers/CloudFormationWriter.cs +++ b/Libraries/src/Amazon.Lambda.Annotations.SourceGenerator/Writers/CloudFormationWriter.cs @@ -1,8 +1,13 @@ -using Amazon.Lambda.Annotations.APIGateway; +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Annotations.ALB; +using Amazon.Lambda.Annotations.APIGateway; using Amazon.Lambda.Annotations.SourceGenerator.Diagnostics; using Amazon.Lambda.Annotations.SourceGenerator.FileIO; using Amazon.Lambda.Annotations.SourceGenerator.Models; using Amazon.Lambda.Annotations.SourceGenerator.Models.Attributes; +using Amazon.Lambda.Annotations.S3; using Amazon.Lambda.Annotations.SQS; using Microsoft.CodeAnalysis; using System; @@ -203,6 +208,8 @@ private void ProcessLambdaFunctionEventAttributes(ILambdaFunctionSerializable la { var currentSyncedEvents = new List(); var currentSyncedEventProperties = new Dictionary>(); + var currentAlbResources = new List(); + var hasFunctionUrl = false; foreach (var attributeModel in lambdaFunction.Attributes) { @@ -221,10 +228,36 @@ private void ProcessLambdaFunctionEventAttributes(ILambdaFunctionSerializable la eventName = ProcessSqsAttribute(lambdaFunction, sqsAttributeModel.Data, currentSyncedEventProperties); currentSyncedEvents.Add(eventName); break; + case AttributeModel albAttributeModel: + var albResourceNames = ProcessAlbApiAttribute(lambdaFunction, albAttributeModel.Data); + currentAlbResources.AddRange(albResourceNames); + break; + case AttributeModel s3AttributeModel: + eventName = ProcessS3Attribute(lambdaFunction, s3AttributeModel.Data, currentSyncedEventProperties); + currentSyncedEvents.Add(eventName); + break; + case AttributeModel functionUrlAttributeModel: + ProcessFunctionUrlAttribute(lambdaFunction, functionUrlAttributeModel.Data); + _templateWriter.SetToken($"Resources.{lambdaFunction.ResourceName}.Metadata.SyncedFunctionUrlConfig", true); + hasFunctionUrl = true; + break; + } + } + + // Remove FunctionUrlConfig only if it was previously created by Annotations (tracked via metadata). + // This preserves any manually-added FunctionUrlConfig that was not created by the source generator. + if (!hasFunctionUrl) + { + var syncedFunctionUrlConfigPath = $"Resources.{lambdaFunction.ResourceName}.Metadata.SyncedFunctionUrlConfig"; + if (_templateWriter.GetToken(syncedFunctionUrlConfigPath, false)) + { + _templateWriter.RemoveToken($"Resources.{lambdaFunction.ResourceName}.Properties.FunctionUrlConfig"); + _templateWriter.RemoveToken(syncedFunctionUrlConfigPath); } } SynchronizeEventsAndProperties(currentSyncedEvents, currentSyncedEventProperties, lambdaFunction); + SynchronizeAlbResources(currentAlbResources, lambdaFunction); } /// @@ -290,6 +323,50 @@ private string ProcessHttpApiAttribute(ILambdaFunctionSerializable lambdaFunctio return eventName; } + /// + /// Writes the configuration to the serverless template. + /// Unlike HttpApi/RestApi, Function URLs are configured as a property on the function resource + /// rather than as an event source. + /// + private void ProcessFunctionUrlAttribute(ILambdaFunctionSerializable lambdaFunction, FunctionUrlAttribute functionUrlAttribute) + { + var functionUrlConfigPath = $"Resources.{lambdaFunction.ResourceName}.Properties.FunctionUrlConfig"; + _templateWriter.SetToken($"{functionUrlConfigPath}.AuthType", functionUrlAttribute.AuthType.ToString()); + + // Always remove the existing Cors block first to clear any stale properties + // from a previous generation pass, then re-emit only the currently configured values. + var corsPath = $"{functionUrlConfigPath}.Cors"; + _templateWriter.RemoveToken(corsPath); + + var hasCors = functionUrlAttribute.AllowOrigins != null + || functionUrlAttribute.AllowMethods != null + || functionUrlAttribute.AllowHeaders != null + || functionUrlAttribute.ExposeHeaders != null + || functionUrlAttribute.AllowCredentials + || functionUrlAttribute.MaxAge > 0; + + if (hasCors) + { + if (functionUrlAttribute.AllowOrigins != null) + _templateWriter.SetToken($"{corsPath}.AllowOrigins", new List(functionUrlAttribute.AllowOrigins), TokenType.List); + + if (functionUrlAttribute.AllowMethods != null) + _templateWriter.SetToken($"{corsPath}.AllowMethods", functionUrlAttribute.AllowMethods.Select(m => m == LambdaHttpMethod.Any ? "*" : m.ToString().ToUpper()).ToList(), TokenType.List); + + if (functionUrlAttribute.AllowHeaders != null) + _templateWriter.SetToken($"{corsPath}.AllowHeaders", new List(functionUrlAttribute.AllowHeaders), TokenType.List); + + if (functionUrlAttribute.ExposeHeaders != null) + _templateWriter.SetToken($"{corsPath}.ExposeHeaders", new List(functionUrlAttribute.ExposeHeaders), TokenType.List); + + if (functionUrlAttribute.AllowCredentials) + _templateWriter.SetToken($"{corsPath}.AllowCredentials", true); + + if (functionUrlAttribute.MaxAge > 0) + _templateWriter.SetToken($"{corsPath}.MaxAge", functionUrlAttribute.MaxAge); + } + } + /// /// Processes all authorizers and writes them to the serverless template as inline authorizers within the API resources. /// AWS SAM expects authorizers to be defined within the Auth.Authorizers property of AWS::Serverless::HttpApi or AWS::Serverless::Api resources. @@ -597,8 +674,285 @@ private string ProcessSqsAttribute(ILambdaFunctionSerializable lambdaFunction, S } /// - /// Writes all properties associated with to the serverless template. + /// Writes all properties associated with to the serverless template. /// + private string ProcessS3Attribute(ILambdaFunctionSerializable lambdaFunction, S3EventAttribute att, Dictionary> syncedEventProperties) + { + var eventName = att.ResourceName; + var eventPath = $"Resources.{lambdaFunction.ResourceName}.Properties.Events.{eventName}"; + + _templateWriter.SetToken($"{eventPath}.Type", "S3"); + + // Bucket - always a Ref since S3 events require the bucket resource in the same template (validated to start with "@") + var bucketName = att.Bucket.Substring(1); + _templateWriter.RemoveToken($"{eventPath}.Properties.Bucket"); + SetEventProperty(syncedEventProperties, lambdaFunction.ResourceName, eventName, $"Bucket.{REF}", bucketName); + + // Events - list of S3 event types (always written since S3 SAM events require it; uses default "s3:ObjectCreated:*" if not explicitly set) + { + var events = att.Events.Split(';').Select(x => x.Trim()).Where(x => !string.IsNullOrWhiteSpace(x)).ToList(); + SetEventProperty(syncedEventProperties, lambdaFunction.ResourceName, eventName, "Events", events, TokenType.List); + } + + // Filter - S3 key filter rules + if (att.IsFilterPrefixSet || att.IsFilterSuffixSet) + { + var rules = new List>(); + + if (att.IsFilterPrefixSet) + { + rules.Add(new Dictionary { { "Name", "prefix" }, { "Value", att.FilterPrefix } }); + } + + if (att.IsFilterSuffixSet) + { + rules.Add(new Dictionary { { "Name", "suffix" }, { "Value", att.FilterSuffix } }); + } + + SetEventProperty(syncedEventProperties, lambdaFunction.ResourceName, eventName, "Filter.S3Key.Rules", rules, TokenType.List); + } + + // Enabled + if (att.IsEnabledSet) + { + SetEventProperty(syncedEventProperties, lambdaFunction.ResourceName, eventName, "Enabled", att.Enabled); + } + + return att.ResourceName; + } + + /// + /// Generates CloudFormation resources for an Application Load Balancer target. + /// Unlike API Gateway events which map to SAM event types, ALB integration requires + /// generating standalone CloudFormation resources: a TargetGroup, a ListenerRule, and a Lambda Permission. + /// + /// List of the three generated CloudFormation resource names for tracking/synchronization. + private List ProcessAlbApiAttribute(ILambdaFunctionSerializable lambdaFunction, ALBApiAttribute att) + { + var baseName = att.IsResourceNameSet ? att.ResourceName : $"{lambdaFunction.ResourceName}ALB"; + var permissionName = $"{baseName}Permission"; + var targetGroupName = $"{baseName}TargetGroup"; + var listenerRuleName = $"{baseName}ListenerRule"; + + // 1. Lambda Permission - allows ELB to invoke the Lambda function + var permPath = $"Resources.{permissionName}"; + if (!_templateWriter.Exists(permPath) || + string.Equals(_templateWriter.GetToken($"{permPath}.Metadata.Tool", string.Empty), CREATION_TOOL, StringComparison.Ordinal)) + { + _templateWriter.SetToken($"{permPath}.Type", "AWS::Lambda::Permission"); + _templateWriter.SetToken($"{permPath}.Metadata.Tool", CREATION_TOOL); + _templateWriter.SetToken($"{permPath}.Properties.FunctionName.{GET_ATTRIBUTE}", new List { lambdaFunction.ResourceName, "Arn" }, TokenType.List); + _templateWriter.SetToken($"{permPath}.Properties.Action", "lambda:InvokeFunction"); + _templateWriter.SetToken($"{permPath}.Properties.Principal", "elasticloadbalancing.amazonaws.com"); + } + + // 2. Target Group - registers the Lambda function as a target + var tgPath = $"Resources.{targetGroupName}"; + if (!_templateWriter.Exists(tgPath) || + string.Equals(_templateWriter.GetToken($"{tgPath}.Metadata.Tool", string.Empty), CREATION_TOOL, StringComparison.Ordinal)) + { + _templateWriter.SetToken($"{tgPath}.Type", "AWS::ElasticLoadBalancingV2::TargetGroup"); + _templateWriter.SetToken($"{tgPath}.Metadata.Tool", CREATION_TOOL); + _templateWriter.SetToken($"{tgPath}.DependsOn", permissionName); + _templateWriter.SetToken($"{tgPath}.Properties.TargetType", "lambda"); + + // MultiValueHeaders must be set via TargetGroupAttributes, not as a top-level property. + // The CFN property "MultiValueHeadersEnabled" does not exist on AWS::ElasticLoadBalancingV2::TargetGroup. + if (att.MultiValueHeaders) + { + _templateWriter.SetToken($"{tgPath}.Properties.TargetGroupAttributes", + new List> + { + new Dictionary + { + { "Key", "lambda.multi_value_headers.enabled" }, + { "Value", "true" } + } + }, TokenType.List); + } + else + { + _templateWriter.RemoveToken($"{tgPath}.Properties.TargetGroupAttributes"); + } + + _templateWriter.SetToken($"{tgPath}.Properties.Targets", new List> + { + new Dictionary + { + { "Id", new Dictionary> { { GET_ATTRIBUTE, new List { lambdaFunction.ResourceName, "Arn" } } } } + } + }, TokenType.List); + } + + // 3. Listener Rule - routes traffic from the ALB listener to the target group + var rulePath = $"Resources.{listenerRuleName}"; + if (!_templateWriter.Exists(rulePath) || + string.Equals(_templateWriter.GetToken($"{rulePath}.Metadata.Tool", string.Empty), CREATION_TOOL, StringComparison.Ordinal)) + { + _templateWriter.SetToken($"{rulePath}.Type", "AWS::ElasticLoadBalancingV2::ListenerRule"); + _templateWriter.SetToken($"{rulePath}.Metadata.Tool", CREATION_TOOL); + + // ListenerArn - handle @reference vs literal ARN + _templateWriter.RemoveToken($"{rulePath}.Properties.ListenerArn"); + if (!string.IsNullOrEmpty(att.ListenerArn) && att.ListenerArn.StartsWith("@")) + { + var refName = att.ListenerArn.Substring(1); + _templateWriter.SetToken($"{rulePath}.Properties.ListenerArn.{REF}", refName); + + // Warn if the referenced resource/parameter doesn't exist in the template + if (!_templateWriter.Exists($"Resources.{refName}") && !_templateWriter.Exists($"{PARAMETERS}.{refName}")) + { + _diagnosticReporter.Report(Diagnostic.Create(DiagnosticDescriptors.AlbListenerReferenceNotFound, Location.None, refName)); + } + } + else + { + _templateWriter.SetToken($"{rulePath}.Properties.ListenerArn", att.ListenerArn); + } + + // Priority + _templateWriter.SetToken($"{rulePath}.Properties.Priority", att.Priority); + + // Conditions + var conditions = new List> + { + new Dictionary + { + { "Field", "path-pattern" }, + { "PathPatternConfig", new Dictionary + { + { "Values", new List { att.PathPattern } } + } + } + } + }; + if (!string.IsNullOrEmpty(att.HostHeader)) + { + conditions.Add(new Dictionary + { + { "Field", "host-header" }, + { "HostHeaderConfig", new Dictionary + { + { "Values", new List { att.HostHeader } } + } + } + }); + } + if (!string.IsNullOrEmpty(att.HttpMethod)) + { + conditions.Add(new Dictionary + { + { "Field", "http-request-method" }, + { "HttpRequestMethodConfig", new Dictionary + { + { "Values", new List { att.HttpMethod.ToUpper() } } + } + } + }); + } + if (!string.IsNullOrEmpty(att.HttpHeaderConditionName) && att.HttpHeaderConditionValues != null && att.HttpHeaderConditionValues.Length > 0) + { + conditions.Add(new Dictionary + { + { "Field", "http-header" }, + { "HttpHeaderConfig", new Dictionary + { + { "HttpHeaderName", att.HttpHeaderConditionName }, + { "Values", att.HttpHeaderConditionValues.ToList() } + } + } + }); + } + if (att.QueryStringConditions != null && att.QueryStringConditions.Length > 0) + { + var keyValuePairs = new List>(); + foreach (var entry in att.QueryStringConditions) + { + var separatorIndex = entry.IndexOf('='); + if (separatorIndex >= 0) + { + var key = entry.Substring(0, separatorIndex); + var value = entry.Substring(separatorIndex + 1); + var kvp = new Dictionary(); + if (!string.IsNullOrEmpty(key)) + { + kvp["Key"] = key; + } + kvp["Value"] = value; + keyValuePairs.Add(kvp); + } + } + if (keyValuePairs.Any()) + { + conditions.Add(new Dictionary + { + { "Field", "query-string" }, + { "QueryStringConfig", new Dictionary + { + { "Values", keyValuePairs } + } + } + }); + } + } + if (att.SourceIpConditions != null && att.SourceIpConditions.Length > 0) + { + conditions.Add(new Dictionary + { + { "Field", "source-ip" }, + { "SourceIpConfig", new Dictionary + { + { "Values", att.SourceIpConditions.ToList() } + } + } + }); + } + _templateWriter.SetToken($"{rulePath}.Properties.Conditions", conditions, TokenType.List); + + // Actions - forward to target group + _templateWriter.SetToken($"{rulePath}.Properties.Actions", new List> + { + new Dictionary + { + { "Type", "forward" }, + { "TargetGroupArn", new Dictionary { { REF, targetGroupName } } } + } + }, TokenType.List); + } + + return new List { permissionName, targetGroupName, listenerRuleName }; + } + + /// + /// Synchronizes ALB resources for a given Lambda function. ALB resources (Permission, TargetGroup, ListenerRule) + /// are standalone top-level CloudFormation resources, so they need separate tracking from SAM events. + /// Previously generated ALB resources that are no longer present in the current compilation are removed. + /// + private void SynchronizeAlbResources(List currentAlbResources, ILambdaFunctionSerializable lambdaFunction) + { + var syncedAlbResourcesPath = $"Resources.{lambdaFunction.ResourceName}.Metadata.SyncedAlbResources"; + + // Get previously synced ALB resources + var previousAlbResources = _templateWriter.GetToken>(syncedAlbResourcesPath, new List()); + + // Remove orphaned ALB resources + var orphanedAlbResources = previousAlbResources.Except(currentAlbResources).ToList(); + foreach (var resourceName in orphanedAlbResources) + { + var resourcePath = $"Resources.{resourceName}"; + // Only remove if it was created by this tool + if (_templateWriter.Exists(resourcePath) && + string.Equals(_templateWriter.GetToken($"{resourcePath}.Metadata.Tool", string.Empty), CREATION_TOOL, StringComparison.Ordinal)) + { + _templateWriter.RemoveToken(resourcePath); + } + } + + // Update synced ALB resources in the template metadata + _templateWriter.RemoveToken(syncedAlbResourcesPath); + if (currentAlbResources.Any()) + _templateWriter.SetToken(syncedAlbResourcesPath, currentAlbResources, TokenType.List); + } /// /// Writes the default values for the Lambda function's metadata and properties. @@ -893,4 +1247,4 @@ private void SynchronizeEventsAndProperties(List syncedEvents, Dictionar _templateWriter.SetToken(syncedEventPropertiesPath, syncedEventProperties, TokenType.KeyVal); } } -} \ No newline at end of file +} diff --git a/Libraries/src/Amazon.Lambda.Annotations/ALB/ALBApiAttribute.cs b/Libraries/src/Amazon.Lambda.Annotations/ALB/ALBApiAttribute.cs new file mode 100644 index 000000000..d73f4e365 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations/ALB/ALBApiAttribute.cs @@ -0,0 +1,188 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; + +namespace Amazon.Lambda.Annotations.ALB +{ + /// + /// Configures the Lambda function to be called from an Application Load Balancer. + /// The source generator will create the necessary CloudFormation resources + /// (TargetGroup, ListenerRule, Lambda Permission) to wire the Lambda function + /// as a target behind the specified ALB listener. + /// + /// + /// The listener ARN (or template reference), path pattern, and priority are required. + /// See ALB Lambda documentation. + /// + [AttributeUsage(AttributeTargets.Method)] + public class ALBApiAttribute : Attribute + { + // Only allow alphanumeric characters for resource names + private static readonly Regex _resourceNameRegex = new Regex("^[a-zA-Z0-9]+$"); + + /// + /// The ARN of the existing ALB listener, or a "@ResourceName" reference to a + /// listener resource or parameter defined in the CloudFormation template. + /// To reference a resource in the serverless template, prefix the resource name with "@" symbol. + /// + public string ListenerArn { get; set; } + + /// + /// The path pattern condition for the ALB listener rule (e.g., "/api/orders/*"). + /// ALB supports wildcard path patterns using "*" and "?" characters. + /// + public string PathPattern { get; set; } + + /// + /// The priority of the ALB listener rule. Must be between 1 and 50000. + /// Lower numbers are evaluated first. Each rule on a listener must have a unique priority. + /// + public int Priority { get; set; } + + /// + /// Whether multi-value headers are enabled on the ALB target group. Default: false. + /// When true, the Lambda function should use MultiValueHeaders and + /// MultiValueQueryStringParameters on the request and response objects. + /// When false, use Headers and QueryStringParameters instead. + /// + public bool MultiValueHeaders + { + get => multiValueHeaders.GetValueOrDefault(); + set => multiValueHeaders = value; + } + private bool? multiValueHeaders { get; set; } + internal bool IsMultiValueHeadersSet => multiValueHeaders.HasValue; + + /// + /// Optional host header condition for the listener rule (e.g., "api.example.com"). + /// When specified, the rule will only match requests with this host header value. + /// + public string HostHeader { get; set; } + + /// + /// Optional HTTP method condition for the listener rule (e.g., "GET", "POST"). + /// When specified, the rule will only match requests with this HTTP method. + /// Leave null to match all HTTP methods. + /// + public string HttpMethod { get; set; } + + /// + /// Optional HTTP header name for an http-header listener rule condition (e.g., "X-Environment", "User-Agent"). + /// Must be used together with . + /// The header name is not case-sensitive. + /// + public string HttpHeaderConditionName { get; set; } + + /// + /// Optional HTTP header values for an http-header listener rule condition (e.g., new[] { "dev", "*Chrome*" }). + /// Supports wildcards (* and ?). Must be used together with . + /// Up to 3 match evaluations per condition. + /// + public string[] HttpHeaderConditionValues { get; set; } + + /// + /// Optional query string key/value pairs for a query-string listener rule condition. + /// Format: "key=value" pairs. Use "=value" (empty key) to match any key with that value. + /// Supports wildcards (* and ?). + /// Example: new[] { "version=v1", "=*example*" } + /// + public string[] QueryStringConditions { get; set; } + + /// + /// Optional source IP CIDR blocks for a source-ip listener rule condition. + /// Example: new[] { "192.0.2.0/24", "198.51.100.10/32" } + /// Supports both IPv4 and IPv6 addresses in CIDR format. + /// + public string[] SourceIpConditions { get; set; } + + /// + /// The CloudFormation resource name prefix for the generated ALB resources + /// (TargetGroup, ListenerRule, Permission). Defaults to "{LambdaResourceName}ALB". + /// Must only contain alphanumeric characters. + /// + public string ResourceName + { + get => resourceName; + set => resourceName = value; + } + private string resourceName { get; set; } + internal bool IsResourceNameSet => resourceName != null; + + /// + /// Creates an instance of the class. + /// + /// The ARN of the ALB listener, or a "@ResourceName" reference to a template resource. + /// The path pattern condition (e.g., "/api/orders/*"). + /// The listener rule priority (1-50000). + public ALBApiAttribute(string listenerArn, string pathPattern, int priority) + { + ListenerArn = listenerArn; + PathPattern = pathPattern; + Priority = priority; + } + + /// + /// Validates the attribute properties and returns a list of validation error messages. + /// + internal List Validate() + { + var validationErrors = new List(); + + if (string.IsNullOrEmpty(ListenerArn)) + { + validationErrors.Add($"{nameof(ListenerArn)} is required and cannot be empty."); + } + else if (!ListenerArn.StartsWith("@")) + { + // If it's not a template reference, validate it looks like an ARN + if (!ListenerArn.StartsWith("arn:")) + { + validationErrors.Add($"{nameof(ListenerArn)} = {ListenerArn}. It must be a valid ARN (starting with 'arn:') or a template reference (starting with '@')."); + } + } + + if (string.IsNullOrEmpty(PathPattern)) + { + validationErrors.Add($"{nameof(PathPattern)} is required and cannot be empty."); + } + + if (Priority < 1 || Priority > 50000) + { + validationErrors.Add($"{nameof(Priority)} = {Priority}. It must be between 1 and 50000."); + } + + if (IsResourceNameSet && !_resourceNameRegex.IsMatch(ResourceName)) + { + validationErrors.Add($"{nameof(ResourceName)} = {ResourceName}. It must only contain alphanumeric characters and must not be an empty string."); + } + + if (!string.IsNullOrEmpty(HttpMethod)) + { + var validMethods = new HashSet(StringComparer.OrdinalIgnoreCase) + { + "GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS" + }; + if (!validMethods.Contains(HttpMethod)) + { + validationErrors.Add($"{nameof(HttpMethod)} = {HttpMethod}. It must be a valid HTTP method (GET, POST, PUT, PATCH, DELETE, HEAD, OPTIONS)."); + } + } + + // Validate http-header condition: both name and values must be set together + if (!string.IsNullOrEmpty(HttpHeaderConditionName) && (HttpHeaderConditionValues == null || HttpHeaderConditionValues.Length == 0)) + { + validationErrors.Add($"{nameof(HttpHeaderConditionName)} is set to '{HttpHeaderConditionName}' but {nameof(HttpHeaderConditionValues)} is not set. Both must be specified together."); + } + if ((HttpHeaderConditionValues != null && HttpHeaderConditionValues.Length > 0) && string.IsNullOrEmpty(HttpHeaderConditionName)) + { + validationErrors.Add($"{nameof(HttpHeaderConditionValues)} is set but {nameof(HttpHeaderConditionName)} is not set. Both must be specified together."); + } + + return validationErrors; + } + } +} diff --git a/Libraries/src/Amazon.Lambda.Annotations/ALB/FromBodyAttribute.cs b/Libraries/src/Amazon.Lambda.Annotations/ALB/FromBodyAttribute.cs new file mode 100644 index 000000000..73e01adbe --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations/ALB/FromBodyAttribute.cs @@ -0,0 +1,18 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System; + +namespace Amazon.Lambda.Annotations.ALB +{ + /// + /// Maps this parameter to the HTTP request body from the ALB request + /// + /// + /// If the parameter is a complex type then the request body will be assumed to be JSON and deserialized into the type. + /// + [AttributeUsage(AttributeTargets.Parameter)] + public class FromBodyAttribute : Attribute + { + } +} diff --git a/Libraries/src/Amazon.Lambda.Annotations/ALB/FromHeaderAttribute.cs b/Libraries/src/Amazon.Lambda.Annotations/ALB/FromHeaderAttribute.cs new file mode 100644 index 000000000..1e8e07cd9 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations/ALB/FromHeaderAttribute.cs @@ -0,0 +1,19 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System; + +namespace Amazon.Lambda.Annotations.ALB +{ + /// + /// Maps this parameter to an HTTP header value from the ALB request + /// + [AttributeUsage(AttributeTargets.Parameter)] + public class FromHeaderAttribute : Attribute, INamedAttribute + { + /// + /// Name of the header. If not specified, the parameter name is used. + /// + public string Name { get; set; } + } +} diff --git a/Libraries/src/Amazon.Lambda.Annotations/ALB/FromQueryAttribute.cs b/Libraries/src/Amazon.Lambda.Annotations/ALB/FromQueryAttribute.cs new file mode 100644 index 000000000..30d229386 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations/ALB/FromQueryAttribute.cs @@ -0,0 +1,19 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System; + +namespace Amazon.Lambda.Annotations.ALB +{ + /// + /// Maps this parameter to a query string parameter from the ALB request + /// + [AttributeUsage(AttributeTargets.Parameter)] + public class FromQueryAttribute : Attribute, INamedAttribute + { + /// + /// Name of the query string parameter. If not specified, the parameter name is used. + /// + public string Name { get; set; } + } +} diff --git a/Libraries/src/Amazon.Lambda.Annotations/APIGateway/FunctionUrlAttribute.cs b/Libraries/src/Amazon.Lambda.Annotations/APIGateway/FunctionUrlAttribute.cs new file mode 100644 index 000000000..a92387762 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations/APIGateway/FunctionUrlAttribute.cs @@ -0,0 +1,51 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System; + +namespace Amazon.Lambda.Annotations.APIGateway +{ + /// + /// Configures the Lambda function to be invoked via a Lambda Function URL. + /// + /// + /// Function URLs use the same payload format as HTTP API v2 (APIGatewayHttpApiV2ProxyRequest/Response). + /// + [AttributeUsage(AttributeTargets.Method)] + public class FunctionUrlAttribute : Attribute + { + /// + public FunctionUrlAuthType AuthType { get; set; } = FunctionUrlAuthType.NONE; + + /// + /// The allowed origins for CORS requests. Example: new[] { "https://example.com" } + /// + public string[] AllowOrigins { get; set; } + + /// + /// The allowed HTTP methods for CORS requests. Example: new[] { LambdaHttpMethod.Get, LambdaHttpMethod.Post } + /// + public LambdaHttpMethod[] AllowMethods { get; set; } + + /// + /// The allowed headers for CORS requests. + /// + public string[] AllowHeaders { get; set; } + + /// + /// Whether credentials are included in the CORS request. + /// + public bool AllowCredentials { get; set; } + + /// + /// The expose headers for CORS responses. + /// + public string[] ExposeHeaders { get; set; } + + /// + /// The maximum time in seconds that a browser can cache the CORS preflight response. + /// A value of 0 means the property is not set. + /// + public int MaxAge { get; set; } + } +} diff --git a/Libraries/src/Amazon.Lambda.Annotations/APIGateway/FunctionUrlAuthType.cs b/Libraries/src/Amazon.Lambda.Annotations/APIGateway/FunctionUrlAuthType.cs new file mode 100644 index 000000000..31a1c2397 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations/APIGateway/FunctionUrlAuthType.cs @@ -0,0 +1,21 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +namespace Amazon.Lambda.Annotations.APIGateway +{ + /// + /// The type of authentication for a Lambda Function URL. + /// + public enum FunctionUrlAuthType + { + /// + /// No authentication. Anyone with the Function URL can invoke the function. + /// + NONE, + + /// + /// IAM authentication. Only authenticated IAM users and roles can invoke the function. + /// + AWS_IAM + } +} diff --git a/Libraries/src/Amazon.Lambda.Annotations/APIGateway/HttpApiAuthorizerAttribute.cs b/Libraries/src/Amazon.Lambda.Annotations/APIGateway/HttpApiAuthorizerAttribute.cs index 17b7d0bf7..725c2842a 100644 --- a/Libraries/src/Amazon.Lambda.Annotations/APIGateway/HttpApiAuthorizerAttribute.cs +++ b/Libraries/src/Amazon.Lambda.Annotations/APIGateway/HttpApiAuthorizerAttribute.cs @@ -8,8 +8,8 @@ namespace Amazon.Lambda.Annotations.APIGateway /// /// /// This attribute must be used in conjunction with the . - /// The authorizer function should return - /// when is true, or + /// The authorizer function should return APIGatewayCustomAuthorizerV2SimpleResponse + /// when is true, or APIGatewayCustomAuthorizerV2IamResponse /// when is false. /// /// @@ -45,8 +45,8 @@ public class HttpApiAuthorizerAttribute : Attribute /// Defaults to true for simpler implementation. /// /// - /// When true, the authorizer should return . - /// When false, the authorizer should return . + /// When true, the authorizer should return APIGatewayCustomAuthorizerV2SimpleResponse. + /// When false, the authorizer should return APIGatewayCustomAuthorizerV2IamResponse. /// public bool EnableSimpleResponses { get; set; } = true; diff --git a/Libraries/src/Amazon.Lambda.Annotations/APIGateway/RestApiAuthorizerAttribute.cs b/Libraries/src/Amazon.Lambda.Annotations/APIGateway/RestApiAuthorizerAttribute.cs index b578e0d97..d6db9fa04 100644 --- a/Libraries/src/Amazon.Lambda.Annotations/APIGateway/RestApiAuthorizerAttribute.cs +++ b/Libraries/src/Amazon.Lambda.Annotations/APIGateway/RestApiAuthorizerAttribute.cs @@ -26,7 +26,7 @@ public enum RestApiAuthorizerType /// /// /// This attribute must be used in conjunction with the . - /// The authorizer function should return . + /// The authorizer function should return APIGatewayCustomAuthorizerResponse. /// /// /// diff --git a/Libraries/src/Amazon.Lambda.Annotations/Amazon.Lambda.Annotations.csproj b/Libraries/src/Amazon.Lambda.Annotations/Amazon.Lambda.Annotations.csproj index 15da7004d..fb365ff5d 100644 --- a/Libraries/src/Amazon.Lambda.Annotations/Amazon.Lambda.Annotations.csproj +++ b/Libraries/src/Amazon.Lambda.Annotations/Amazon.Lambda.Annotations.csproj @@ -11,7 +11,8 @@ ..\..\..\buildtools\public.snk true - 1.10.0 + 1.13.0 + true diff --git a/Libraries/src/Amazon.Lambda.Annotations/README.md b/Libraries/src/Amazon.Lambda.Annotations/README.md index 75dfaac23..45bce6783 100644 --- a/Libraries/src/Amazon.Lambda.Annotations/README.md +++ b/Libraries/src/Amazon.Lambda.Annotations/README.md @@ -19,6 +19,8 @@ Topics: - [Amazon API Gateway example](#amazon-api-gateway-example) - [Amazon S3 example](#amazon-s3-example) - [SQS Event Example](#sqs-event-example) + - [Application Load Balancer (ALB) Example](#application-load-balancer-alb-example) + - [Lambda Function URL Example](#lambda-function-url-example) - [Custom Lambda Authorizer Example](#custom-lambda-authorizer-example) - [HTTP API Authorizer](#http-api-authorizer) - [REST API Authorizer](#rest-api-authorizer) @@ -852,6 +854,330 @@ The following SQS event source mapping will be generated for the `SQSMessageHand } ``` +## Application Load Balancer (ALB) Example + +This example shows how to use the `ALBApi` attribute to configure a Lambda function as a target behind an [Application Load Balancer](https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html). Unlike API Gateway event attributes that map to SAM event types, the ALB integration generates standalone CloudFormation resources — a `TargetGroup`, a `ListenerRule`, and a `Lambda::Permission` — to wire the Lambda function to an existing ALB listener. + +The `ALBApi` attribute contains the following properties: + +| Property | Type | Required | Default | Description | +|---|---|---|---|---| +| `ListenerArn` | `string` | Yes | — | The ARN of the existing ALB listener, or a `@ResourceName` reference to a listener resource defined in the CloudFormation template. | +| `PathPattern` | `string` | Yes | — | The path pattern condition for the listener rule (e.g., `"/api/orders/*"`). Supports wildcard characters `*` and `?`. | +| `Priority` | `int` | Yes | — | The listener rule priority (1–50000). Lower numbers are evaluated first. Must be unique per listener. | +| `MultiValueHeaders` | `bool` | No | `false` | When `true`, enables multi-value headers on the target group. The function should then use `MultiValueHeaders` and `MultiValueQueryStringParameters` on request/response objects. | +| `HostHeader` | `string` | No | `null` | Optional host header condition (e.g., `"api.example.com"`). | +| `HttpMethod` | `string` | No | `null` | Optional HTTP method condition (e.g., `"GET"`, `"POST"`). Leave null to match all methods. | +| `ResourceName` | `string` | No | `"{LambdaResourceName}ALB"` | Custom CloudFormation resource name prefix for the generated resources. Must be alphanumeric. | + +The `ALBApi` attribute must be applied to a Lambda method along with the `LambdaFunction` attribute. + +The Lambda method must conform to the following rules when tagged with the `ALBApi` attribute: + +1. It must have at least 1 argument and can have at most 2 arguments. + - The first argument is required and must be of type `ApplicationLoadBalancerRequest` defined in the [Amazon.Lambda.ApplicationLoadBalancerEvents](https://github.com/aws/aws-lambda-dotnet/tree/master/Libraries/src/Amazon.Lambda.ApplicationLoadBalancerEvents) package. + - The second argument is optional and must be of type `ILambdaContext`. +2. The method return type must be `ApplicationLoadBalancerResponse` or `Task`. + +### Prerequisites + +Your CloudFormation template must include an existing ALB and listener. The `ALBApi` attribute references the listener — it does **not** create the ALB or listener for you. You can define them in the same template or reference one that already exists via its ARN. + +### Basic Example + +This example creates a simple hello endpoint behind an ALB listener that is defined elsewhere in the template: + +```csharp +using Amazon.Lambda.Annotations; +using Amazon.Lambda.Annotations.ALB; +using Amazon.Lambda.ApplicationLoadBalancerEvents; +using Amazon.Lambda.Core; +using System.Collections.Generic; + +public class ALBFunctions +{ + [LambdaFunction(ResourceName = "ALBHello", MemorySize = 256, Timeout = 15)] + [ALBApi("@ALBTestListener", "/hello", 1)] + public ApplicationLoadBalancerResponse Hello(ApplicationLoadBalancerRequest request, ILambdaContext context) + { + context.Logger.LogInformation($"Hello endpoint hit. Path: {request.Path}"); + + return new ApplicationLoadBalancerResponse + { + StatusCode = 200, + StatusDescription = "200 OK", + IsBase64Encoded = false, + Headers = new Dictionary + { + { "Content-Type", "application/json" } + }, + Body = $"{{\"message\": \"Hello from ALB Lambda!\", \"path\": \"{request.Path}\"}}" + }; + } +} +``` + +In the example above, `@ALBTestListener` references a listener resource called `ALBTestListener` defined in the same CloudFormation template. The `@` prefix tells the source generator to use a `Ref` intrinsic function instead of a literal ARN string. + +### Using a Literal Listener ARN + +If you want to reference an ALB listener in a different stack or one that was created outside of CloudFormation, use the full ARN: + +```csharp +[LambdaFunction(ResourceName = "ALBHandler")] +[ALBApi("arn:aws:elasticloadbalancing:us-east-1:123456789012:listener/app/my-alb/abc123/def456", "/api/*", 10)] +public ApplicationLoadBalancerResponse HandleRequest(ApplicationLoadBalancerRequest request, ILambdaContext context) +{ + return new ApplicationLoadBalancerResponse + { + StatusCode = 200, + Headers = new Dictionary { { "Content-Type", "application/json" } }, + Body = "{\"status\": \"ok\"}" + }; +} +``` + +### Advanced Example with All Options + +This example shows all optional properties including host header filtering, HTTP method filtering, multi-value headers, and a custom resource name: + +```csharp +[LambdaFunction(ResourceName = "ALBOrders")] +[ALBApi("@MyListener", "/api/orders/*", 5, + MultiValueHeaders = true, + HostHeader = "api.example.com", + HttpMethod = "POST", + ResourceName = "OrdersALB")] +public ApplicationLoadBalancerResponse CreateOrder(ApplicationLoadBalancerRequest request, ILambdaContext context) +{ + // When MultiValueHeaders is true, use MultiValueHeaders and MultiValueQueryStringParameters + var contentTypes = request.MultiValueHeaders?["content-type"]; + + return new ApplicationLoadBalancerResponse + { + StatusCode = 201, + StatusDescription = "201 Created", + MultiValueHeaders = new Dictionary> + { + { "Content-Type", new List { "application/json" } }, + { "X-Custom-Header", new List { "value1", "value2" } } + }, + Body = "{\"orderId\": \"12345\"}" + }; +} +``` + +### Generated CloudFormation Resources + +For each `ALBApi` attribute, the source generator creates three CloudFormation resources. Here is an example of the generated template for the basic hello endpoint: + +```json +"ALBHello": { + "Type": "AWS::Serverless::Function", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations" + }, + "Properties": { + "Runtime": "dotnet8", + "CodeUri": ".", + "MemorySize": 512, + "Timeout": 15, + "Policies": ["AWSLambdaBasicExecutionRole"], + "PackageType": "Zip", + "Handler": "MyProject::MyNamespace.ALBFunctions_Hello_Generated::Hello" + } +}, +"ALBHelloALBPermission": { + "Type": "AWS::Lambda::Permission", + "Metadata": { "Tool": "Amazon.Lambda.Annotations" }, + "Properties": { + "FunctionName": { "Fn::GetAtt": ["ALBHello", "Arn"] }, + "Action": "lambda:InvokeFunction", + "Principal": "elasticloadbalancing.amazonaws.com" + } +}, +"ALBHelloALBTargetGroup": { + "Type": "AWS::ElasticLoadBalancingV2::TargetGroup", + "Metadata": { "Tool": "Amazon.Lambda.Annotations" }, + "DependsOn": "ALBHelloALBPermission", + "Properties": { + "TargetType": "lambda", + "Targets": [ + { "Id": { "Fn::GetAtt": ["ALBHello", "Arn"] } } + ] + } +}, +"ALBHelloALBListenerRule": { + "Type": "AWS::ElasticLoadBalancingV2::ListenerRule", + "Metadata": { "Tool": "Amazon.Lambda.Annotations" }, + "Properties": { + "ListenerArn": { "Ref": "ALBTestListener" }, + "Priority": 1, + "Conditions": [ + { "Field": "path-pattern", "Values": ["/hello"] } + ], + "Actions": [ + { "Type": "forward", "TargetGroupArn": { "Ref": "ALBHelloALBTargetGroup" } } + ] + } +} +``` + +When `MultiValueHeaders` is set to `true`, the target group will include a `TargetGroupAttributes` section: + +```json +"TargetGroupAttributes": [ + { "Key": "lambda.multi_value_headers.enabled", "Value": "true" } +] +``` + +When `HostHeader` or `HttpMethod` are specified, additional conditions are added to the listener rule: + +```json +"Conditions": [ + { "Field": "path-pattern", "Values": ["/api/orders/*"] }, + { "Field": "host-header", "Values": ["api.example.com"] }, + { "Field": "http-request-method", "Values": ["POST"] } +] +``` + +### Setting Up the ALB in the Template + +The `ALBApi` attribute requires an existing ALB listener. Here is a minimal example of the infrastructure resources you would add to your `serverless.template`: + +```json +{ + "MyVPC": { "Type": "AWS::EC2::VPC", "Properties": { "CidrBlock": "10.0.0.0/16" } }, + "MySubnet1": { "Type": "AWS::EC2::Subnet", "Properties": { "VpcId": { "Ref": "MyVPC" }, "CidrBlock": "10.0.1.0/24" } }, + "MySubnet2": { "Type": "AWS::EC2::Subnet", "Properties": { "VpcId": { "Ref": "MyVPC" }, "CidrBlock": "10.0.2.0/24" } }, + "MySecurityGroup": { "Type": "AWS::EC2::SecurityGroup", "Properties": { "GroupDescription": "ALB SG", "VpcId": { "Ref": "MyVPC" } } }, + "MyALB": { + "Type": "AWS::ElasticLoadBalancingV2::LoadBalancer", + "Properties": { + "Type": "application", + "Scheme": "internet-facing", + "Subnets": [{ "Ref": "MySubnet1" }, { "Ref": "MySubnet2" }], + "SecurityGroups": [{ "Ref": "MySecurityGroup" }] + } + }, + "MyListener": { + "Type": "AWS::ElasticLoadBalancingV2::Listener", + "Properties": { + "LoadBalancerArn": { "Ref": "MyALB" }, + "Port": 80, + "Protocol": "HTTP", + "DefaultActions": [{ "Type": "fixed-response", "FixedResponseConfig": { "StatusCode": "404" } }] + } + } +} +``` + +Then your Lambda function references `@MyListener` in the `ALBApi` attribute. + +## Lambda Function URL Example + +[Lambda Function URLs](https://docs.aws.amazon.com/lambda/latest/dg/lambda-urls.html) provide a dedicated HTTPS endpoint for your Lambda function without needing API Gateway or an Application Load Balancer. The `FunctionUrl` attribute configures the function to be invoked via a Function URL. Function URLs use the same payload format as HTTP API v2 (`APIGatewayHttpApiV2ProxyRequest`/`APIGatewayHttpApiV2ProxyResponse`). + +The `FunctionUrl` attribute contains the following properties: + +| Property | Type | Required | Default | Description | +|---|---|---|---|---| +| `AuthType` | `FunctionUrlAuthType` | No | `NONE` | The authentication type: `NONE` (public) or `AWS_IAM` (IAM-authenticated). | +| `AllowOrigins` | `string[]` | No | `null` | Allowed origins for CORS requests (e.g., `new[] { "https://example.com" }`). | +| `AllowMethods` | `LambdaHttpMethod[]` | No | `null` | Allowed HTTP methods for CORS requests, using the `LambdaHttpMethod` enum (e.g., `new[] { LambdaHttpMethod.Get, LambdaHttpMethod.Post }`). | +| `AllowHeaders` | `string[]` | No | `null` | Allowed headers for CORS requests. | +| `ExposeHeaders` | `string[]` | No | `null` | Headers to expose in CORS responses. | +| `AllowCredentials` | `bool` | No | `false` | Whether credentials are included in the CORS request. | +| `MaxAge` | `int` | No | `0` | Maximum time in seconds that a browser can cache the CORS preflight response. `0` means not set. | + +### Basic Example + +A simple function with a public Function URL (no authentication): + +```csharp +using Amazon.Lambda.Annotations; +using Amazon.Lambda.Annotations.APIGateway; +using Amazon.Lambda.Core; + +public class Functions +{ + [LambdaFunction(PackageType = LambdaPackageType.Image)] + [FunctionUrl(AuthType = FunctionUrlAuthType.NONE)] + public IHttpResult GetItems([FromQuery] string category, ILambdaContext context) + { + context.Logger.LogLine($"Getting items for category: {category}"); + return HttpResults.Ok(new { items = new[] { "item1", "item2" }, category }); + } +} +``` + +### With IAM Authentication + +Use `FunctionUrlAuthType.AWS_IAM` to require IAM authentication for the Function URL: + +```csharp +[LambdaFunction(PackageType = LambdaPackageType.Image)] +[FunctionUrl(AuthType = FunctionUrlAuthType.AWS_IAM)] +public IHttpResult SecureEndpoint(ILambdaContext context) +{ + return HttpResults.Ok(new { message = "This endpoint requires IAM auth" }); +} +``` + +### With CORS Configuration + +Configure CORS settings directly on the attribute. The `AllowMethods` property uses the type-safe `LambdaHttpMethod` enum, consistent with the `HttpApi` and `RestApi` attributes: + +```csharp +[LambdaFunction(PackageType = LambdaPackageType.Image)] +[FunctionUrl( + AuthType = FunctionUrlAuthType.NONE, + AllowOrigins = new[] { "https://example.com", "https://app.example.com" }, + AllowMethods = new[] { LambdaHttpMethod.Get, LambdaHttpMethod.Post }, + AllowHeaders = new[] { "Content-Type", "Authorization" }, + AllowCredentials = true, + MaxAge = 3600)] +public IHttpResult GetData([FromQuery] string id, ILambdaContext context) +{ + return HttpResults.Ok(new { id, data = "some data" }); +} +``` + +### Generated CloudFormation + +The source generator creates a `FunctionUrlConfig` property on the Lambda function resource (not a SAM event source). Here is an example with CORS: + +```json +"GetDataFunction": { + "Type": "AWS::Serverless::Function", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations", + "SyncedFunctionUrlConfig": true + }, + "Properties": { + "PackageType": "Image", + "ImageUri": ".", + "ImageConfig": { + "Command": ["MyAssembly::MyNamespace.Functions_GetData_Generated::GetData"] + }, + "MemorySize": 512, + "Timeout": 30, + "FunctionUrlConfig": { + "AuthType": "NONE", + "Cors": { + "AllowOrigins": ["https://example.com", "https://app.example.com"], + "AllowMethods": ["GET", "POST"], + "AllowHeaders": ["Content-Type", "Authorization"], + "AllowCredentials": true, + "MaxAge": 3600 + } + } + } +} +``` + +> **Note:** Unlike `HttpApi` and `RestApi` which create SAM event sources, `FunctionUrl` configures the `FunctionUrlConfig` property directly on the function resource. If the `FunctionUrl` attribute is removed from the code, the source generator will automatically clean up the `FunctionUrlConfig` from the CloudFormation template. + ## Custom Lambda Authorizer Example Lambda Annotations supports defining custom Lambda authorizers using attributes. Custom authorizers let you control access to your API Gateway endpoints by running a Lambda function that validates tokens or request parameters before the target function is invoked. The source generator automatically wires up the authorizer resources and references in the CloudFormation template. @@ -1198,7 +1524,11 @@ parameter to the `LambdaFunction` must be the event object and the event source * RestApiAuthorizer * Marks a Lambda function as a REST API (API Gateway V1) custom authorizer. The authorizer name is automatically derived from the method name. Other functions reference it via `RestApi.Authorizer` using `nameof()`. Use the `Type` property to choose between `Token` and `Request` authorizer types. * SQSEvent - * Sets up event source mapping between the Lambda function and SQS queues. The SQS queue ARN is required to be set on the attribute. If users want to pass a reference to an existing SQS queue resource defined in their CloudFormation template, they can pass the SQS queue resource name prefixed with the '@' symbol. + * Sets up event source mapping between the Lambda function and SQS queues. The SQS queue ARN is required to be set on the attribute. If users want to pass a reference to an existing SQS queue resource defined in their CloudFormation template, they can pass the SQS queue resource name prefixed with the '@' symbol. +* ALBApi + * Configures the Lambda function to be called from an Application Load Balancer. The listener ARN (or `@ResourceName` template reference), path pattern, and priority are required. The source generator creates standalone CloudFormation resources (TargetGroup, ListenerRule, Lambda Permission) rather than SAM event types. +* FunctionUrl + * Configures the Lambda function to be invoked via a Lambda Function URL. Supports `AuthType` (`NONE` or `AWS_IAM`) and CORS configuration including `AllowMethods` (using the `LambdaHttpMethod` enum), `AllowOrigins`, `AllowHeaders`, `AllowCredentials`, and `MaxAge`. The source generator writes a `FunctionUrlConfig` property on the function resource rather than a SAM event source. ### Parameter Attributes @@ -1277,3 +1607,5 @@ The content type is determined using the following rules. ## Project References If API Gateway event attributes, such as `RestAPI` or `HttpAPI`, are being used then a package reference to `Amazon.Lambda.APIGatewayEvents` must be added to the project, otherwise the project will not compile. We do not include it by default in order to keep the `Amazon.Lambda.Annotations` library lightweight. + +Similarly, if the `ALBApi` attribute is being used then a package reference to `Amazon.Lambda.ApplicationLoadBalancerEvents` must be added to the project. This provides the `ApplicationLoadBalancerRequest` and `ApplicationLoadBalancerResponse` types used by ALB Lambda functions. diff --git a/Libraries/src/Amazon.Lambda.Annotations/S3/S3EventAttribute.cs b/Libraries/src/Amazon.Lambda.Annotations/S3/S3EventAttribute.cs new file mode 100644 index 000000000..13bc79095 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Annotations/S3/S3EventAttribute.cs @@ -0,0 +1,129 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System; +using System.Collections.Generic; +using System.Text.RegularExpressions; + +namespace Amazon.Lambda.Annotations.S3 +{ + /// + /// This attribute defines the S3 event source configuration for a Lambda function. + /// + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + public class S3EventAttribute : Attribute + { + private static readonly Regex _resourceNameRegex = new Regex("^[a-zA-Z0-9]+$"); + + /// + /// The S3 bucket that will act as the event trigger for the Lambda function. + /// This must be a reference to an S3 bucket resource defined in the serverless template, prefixed with "@". + /// + public string Bucket { get; set; } + + /// + /// The CloudFormation resource name for the S3 event. By default this is derived from the Bucket reference without the "@" prefix. + /// + public string ResourceName + { + get + { + if (IsResourceNameSet) + return resourceName; + if (!string.IsNullOrEmpty(Bucket) && Bucket.StartsWith("@")) + return Bucket.Substring(1); + return Bucket; + } + set => resourceName = value; + } + private string resourceName = null; + internal bool IsResourceNameSet => resourceName != null; + + /// + /// Semicolon-separated list of S3 event types. Default is 's3:ObjectCreated:*'. + /// + public string Events + { + get => events ?? "s3:ObjectCreated:*"; + set => events = value; + } + private string events = null; + internal bool IsEventsSet => events != null; + + /// + /// S3 key prefix filter for the event notification. + /// + public string FilterPrefix + { + get => filterPrefix; + set => filterPrefix = value; + } + private string filterPrefix = null; + internal bool IsFilterPrefixSet => filterPrefix != null; + + /// + /// S3 key suffix filter for the event notification. + /// + public string FilterSuffix + { + get => filterSuffix; + set => filterSuffix = value; + } + private string filterSuffix = null; + internal bool IsFilterSuffixSet => filterSuffix != null; + + /// + /// If set to false, the event source will be disabled. Default value is true. + /// + public bool Enabled + { + get => enabled.GetValueOrDefault(true); + set => enabled = value; + } + private bool? enabled; + internal bool IsEnabledSet => enabled.HasValue; + + /// + /// Creates an instance of the class. + /// + /// property + public S3EventAttribute(string bucket) + { + Bucket = bucket; + } + + internal List Validate() + { + var validationErrors = new List(); + + if (string.IsNullOrEmpty(Bucket)) + { + validationErrors.Add($"{nameof(S3EventAttribute.Bucket)} is required and must not be empty"); + } + else if (!Bucket.StartsWith("@")) + { + validationErrors.Add($"{nameof(S3EventAttribute.Bucket)} = {Bucket}. S3 event sources require a reference to an S3 bucket resource in the serverless template. Prefix the resource name with '@'"); + } + else + { + var bucketResourceName = Bucket.Substring(1); + if (!_resourceNameRegex.IsMatch(bucketResourceName)) + { + validationErrors.Add($"{nameof(S3EventAttribute.Bucket)} = {Bucket}. The referenced S3 bucket resource name must not be empty and must only contain alphanumeric characters after the '@' prefix"); + } + } + + if (IsResourceNameSet && !_resourceNameRegex.IsMatch(ResourceName)) + { + validationErrors.Add($"{nameof(S3EventAttribute.ResourceName)} = {ResourceName}. It must only contain alphanumeric characters and must not be an empty string"); + } + + if (string.IsNullOrEmpty(Events)) + { + validationErrors.Add($"{nameof(S3EventAttribute.Events)} must not be empty"); + } + + return validationErrors; + } + } +} diff --git a/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/HttpResponseStreamPrelude.cs b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/HttpResponseStreamPrelude.cs new file mode 100644 index 000000000..67eb9d3ae --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/HttpResponseStreamPrelude.cs @@ -0,0 +1,95 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +#if NET8_0_OR_GREATER +using System; +using System.Collections.Generic; +using System.Net; +using System.Runtime.Versioning; +using System.Text.Json; + +namespace Amazon.Lambda.Core.ResponseStreaming +{ + /// + /// The HTTP response prelude to be sent as the first chunk of a streaming response when using . + /// + [RequiresPreviewFeatures(LambdaResponseStreamFactory.PreviewMessage)] + public class HttpResponseStreamPrelude + { + /// + /// The Http status code to include in the response prelude. + /// + public HttpStatusCode? StatusCode { get; set; } + + /// + /// The response headers to include in the response prelude. This collection supports setting single value for the same headers. + /// + public IDictionary Headers { get; set; } = new Dictionary(); + + /// + /// The response headers to include in the response prelude. This collection supports setting multiple values for the same headers. + /// + public IDictionary> MultiValueHeaders { get; set; } = new Dictionary>(); + + /// + /// The list of cookies to include in the response prelude. This is used for Lambda Function URL responses, which support a separate "cookies" field in the response JSON for setting cookies, rather than requiring cookies to be set via the "Set-Cookie" header. + /// + public IList Cookies { get; set; } = new List(); + + internal byte[] ToByteArray() + { + var bufferWriter = new System.Buffers.ArrayBufferWriter(); + using (var writer = new Utf8JsonWriter(bufferWriter)) + { + writer.WriteStartObject(); + + if (StatusCode.HasValue) + writer.WriteNumber("statusCode", (int)StatusCode); + + if (Headers?.Count > 0) + { + writer.WriteStartObject("headers"); + foreach (var header in Headers) + { + writer.WriteString(header.Key, header.Value); + } + writer.WriteEndObject(); + } + + if (MultiValueHeaders?.Count > 0) + { + writer.WriteStartObject("multiValueHeaders"); + foreach (var header in MultiValueHeaders) + { + writer.WriteStartArray(header.Key); + foreach (var value in header.Value) + { + writer.WriteStringValue(value); + } + writer.WriteEndArray(); + } + writer.WriteEndObject(); + } + + if (Cookies?.Count > 0) + { + writer.WriteStartArray("cookies"); + foreach (var cookie in Cookies) + { + writer.WriteStringValue(cookie); + } + writer.WriteEndArray(); + } + + writer.WriteEndObject(); + } + + if (string.Equals(Environment.GetEnvironmentVariable("LAMBDA_NET_SERIALIZER_DEBUG"), "true", StringComparison.OrdinalIgnoreCase)) + { + LambdaLogger.Log(LogLevel.Information, "HTTP Response Stream Prelude JSON: {Prelude}", System.Text.Encoding.UTF8.GetString(bufferWriter.WrittenSpan)); + } + + return bufferWriter.WrittenSpan.ToArray(); + } + } +} +#endif diff --git a/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/ILambdaResponseStream.cs b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/ILambdaResponseStream.cs new file mode 100644 index 000000000..1385e551e --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/ILambdaResponseStream.cs @@ -0,0 +1,40 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +#if NET8_0_OR_GREATER +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Amazon.Lambda.Core.ResponseStreaming +{ + /// + /// Interface for writing streaming responses in AWS Lambda functions. + /// Obtained by calling within a handler. + /// + internal interface ILambdaResponseStream : IDisposable + { + /// + /// Asynchronously writes a portion of a byte array to the response stream. + /// + /// The byte array containing data to write. + /// The zero-based byte offset in buffer at which to begin copying bytes. + /// The number of bytes to write. + /// Optional cancellation token. + /// A task representing the asynchronous operation. + /// Thrown if the stream is already completed or an error has been reported. + Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default); + + + /// + /// Gets the total number of bytes written to the stream so far. + /// + long BytesWritten { get; } + + + /// + /// Gets whether an error has been reported. + /// + bool HasError { get; } + } +} +#endif diff --git a/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/LambdaResponseStream.cs b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/LambdaResponseStream.cs new file mode 100644 index 000000000..83ac446a4 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/LambdaResponseStream.cs @@ -0,0 +1,123 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +#if NET8_0_OR_GREATER + +using System; +using System.IO; +using System.Runtime.Versioning; +using System.Threading; +using System.Threading.Tasks; + +namespace Amazon.Lambda.Core.ResponseStreaming +{ + /// + /// A write-only, non-seekable subclass that streams response data + /// to the Lambda Runtime API. Returned by . + /// Integrates with standard .NET stream consumers such as . + /// + [RequiresPreviewFeatures(LambdaResponseStreamFactory.PreviewMessage)] + public class LambdaResponseStream : Stream + { + private readonly ILambdaResponseStream _responseStream; + + internal LambdaResponseStream(ILambdaResponseStream responseStream) + { + _responseStream = responseStream; + } + + /// + /// The number of bytes written to the Lambda response stream so far. + /// + public long BytesWritten => _responseStream.BytesWritten; + + /// + /// Asynchronously writes a byte array to the response stream. + /// + /// The byte array to write. + /// Optional cancellation token. + /// A task representing the asynchronous operation. + /// Thrown if the stream is already completed or an error has been reported. + public async Task WriteAsync(byte[] buffer, CancellationToken cancellationToken = default) + { + if (buffer == null) + throw new ArgumentNullException(nameof(buffer)); + + await WriteAsync(buffer, 0, buffer.Length, cancellationToken); + } + + /// + /// Asynchronously writes a portion of a byte array to the response stream. + /// + /// The byte array containing data to write. + /// The zero-based byte offset in buffer at which to begin copying bytes. + /// The number of bytes to write. + /// Optional cancellation token. + /// A task representing the asynchronous operation. + /// Thrown if the stream is already completed or an error has been reported. + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + { + await _responseStream.WriteAsync(buffer, offset, count, cancellationToken); + } + + #region Noop Overrides + + /// Gets a value indicating whether the stream supports reading. Always false. + public override bool CanRead => false; + + /// Gets a value indicating whether the stream supports seeking. Always false. + public override bool CanSeek => false; + + /// Gets a value indicating whether the stream supports writing. Always true. + public override bool CanWrite => true; + + /// + /// Gets the total number of bytes written to the stream so far. + /// Equivalent to . + /// + public override long Length => BytesWritten; + + /// + /// Getting or setting the position is not supported. + /// + /// Always thrown. + public override long Position + { + get => throw new NotSupportedException($"{nameof(LambdaResponseStream)} does not support seeking."); + set => throw new NotSupportedException($"{nameof(LambdaResponseStream)} does not support seeking."); + } + + /// Not supported. + /// Always thrown. + public override long Seek(long offset, SeekOrigin origin) + => throw new NotImplementedException($"{nameof(LambdaResponseStream)} does not support seeking."); + + /// Not supported. + /// Always thrown. + public override int Read(byte[] buffer, int offset, int count) + => throw new NotImplementedException($"{nameof(LambdaResponseStream)} does not support reading."); + + /// Not supported. + /// Always thrown. + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => throw new NotImplementedException($"{nameof(LambdaResponseStream)} does not support reading."); + + /// + /// Writes a sequence of bytes to the stream. Delegates to the async path synchronously. + /// Prefer to avoid blocking. + /// + public override void Write(byte[] buffer, int offset, int count) + => WriteAsync(buffer, offset, count, CancellationToken.None).GetAwaiter().GetResult(); + + /// + /// Flush is a no-op; data is sent to the Runtime API immediately on each write. + /// + public override void Flush() { } + + /// Not supported. + /// Always thrown. + public override void SetLength(long value) + => throw new NotSupportedException($"{nameof(LambdaResponseStream)} does not support SetLength."); + #endregion + } +} +#endif diff --git a/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/LambdaResponseStreamFactory.cs b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/LambdaResponseStreamFactory.cs new file mode 100644 index 000000000..1b9e6d3b6 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/LambdaResponseStreamFactory.cs @@ -0,0 +1,72 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +#if NET8_0_OR_GREATER +using System; +using System.IO; +using System.Runtime.Versioning; + +namespace Amazon.Lambda.Core.ResponseStreaming +{ + /// + /// Factory to create Lambda response streams for writing streaming responses in AWS Lambda functions. The created streams are write-only and non-seekable. + /// + [RequiresPreviewFeatures(LambdaResponseStreamFactory.PreviewMessage)] + public class LambdaResponseStreamFactory + { + internal const string PreviewMessage = + "Response streaming is in preview till a new version of .NET Lambda runtime client that supports response streaming " + + "has been deployed to the .NET Lambda managed runtime. Till deployment has been made the feature can be used by deploying as an " + + "executable including the latest version of Amazon.Lambda.RuntimeSupport and setting the \"EnablePreviewFeatures\" in the Lambda " + + "project file to \"true\""; + + internal const string UninitializedFactoryMessage = + "LambdaResponseStreamFactory is not initialized. This is caused by mismatch versions of Amazon.Lambda.Core and Amazon.Lambda.RuntimeSupport. " + + "Update both packages to the current version to address the issue."; + + private static Func _streamFactory; + + internal static void SetLambdaResponseStream(Func streamFactory) + { + _streamFactory = streamFactory ?? throw new ArgumentNullException(nameof(streamFactory)); + } + + /// + /// Creates a a subclass of that can be used to write streaming responses back to callers of the Lambda function. Once + /// a Lambda function creates a response stream all output must be returned by writing to the stream; the Lambda function's handler + /// return value will be ignored. The stream is write-only and non-seekable. + /// + /// + public static LambdaResponseStream CreateStream() + { + if (_streamFactory == null) + throw new InvalidOperationException(UninitializedFactoryMessage); + + var runtimeResponseStream = _streamFactory(Array.Empty()); + return new LambdaResponseStream(runtimeResponseStream); + } + + /// + /// Creates a a subclass of for writing streaming responses, with an HTTP response prelude containing status code and headers. This should be used for + /// Lambda functions using response streaming that are invoked via the Lambda Function URLs or API Gateway HTTP APIs, where the response format is expected to be an HTTP response. + /// The prelude will be serialized and sent as the first chunk of the response stream, and should contain any necessary HTTP status code and headers. + /// + /// Once a Lambda function creates a response stream all output must be returned by writing to the stream; the Lambda function's handler + /// return value will be ignored. The stream is write-only and non-seekable. + /// + /// + /// The HTTP response prelude including status code and headers. + /// + public static LambdaResponseStream CreateHttpStream(HttpResponseStreamPrelude prelude) + { + if (_streamFactory == null) + throw new InvalidOperationException(UninitializedFactoryMessage); + + if (prelude is null) + throw new ArgumentNullException(nameof(prelude)); + + var runtimeResponseStream = _streamFactory(prelude.ToByteArray()); + return new LambdaResponseStream(runtimeResponseStream); + } + } +} +#endif diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Amazon.Lambda.RuntimeSupport.csproj b/Libraries/src/Amazon.Lambda.RuntimeSupport/Amazon.Lambda.RuntimeSupport.csproj index b3bfb0488..6f8dabfa2 100644 --- a/Libraries/src/Amazon.Lambda.RuntimeSupport/Amazon.Lambda.RuntimeSupport.csproj +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Amazon.Lambda.RuntimeSupport.csproj @@ -4,7 +4,7 @@ netstandard2.0;net6.0;net8.0;net9.0;net10.0;net11.0 - 1.14.2 + 1.14.3 Provides a bootstrap and Lambda Runtime API Client to help you to develop custom .NET Core Lambda Runtimes. Amazon.Lambda.RuntimeSupport Amazon.Lambda.RuntimeSupport diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/LambdaBootstrap.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/LambdaBootstrap.cs index 0e00f3e7f..bb6198d9e 100644 --- a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/LambdaBootstrap.cs +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/LambdaBootstrap.cs @@ -20,6 +20,7 @@ using System.Threading; using System.Threading.Tasks; using Amazon.Lambda.RuntimeSupport.Bootstrap; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; using Amazon.Lambda.RuntimeSupport.Helpers; namespace Amazon.Lambda.RuntimeSupport @@ -225,6 +226,19 @@ internal LambdaBootstrap(HttpClient httpClient, LambdaBootstrapHandler handler, return; } #if NET8_0_OR_GREATER + + try + { + // Initalize in Amazon.Lambda.Core the factory for creating the response stream and related logic for supporting response streaming. + ResponseStreamLambdaCoreInitializerIsolated.InitializeCore(); + } + catch (TypeLoadException) + { + _logger.LogDebug("Failed to configure Amazon.Lambda.Core with factory to create response stream. This happens when the version of Amazon.Lambda.Core referenced by the Lambda function is out of date."); + } + + + // Check if Initialization type is SnapStart, and invoke the snapshot restore logic. if (_configuration.IsInitTypeSnapstart) { @@ -349,6 +363,7 @@ internal async Task InvokeOnceAsync(CancellationToken cancellationToken = defaul _logger.LogInformation("Starting InvokeOnceAsync"); var invocation = await Client.GetNextInvocationAsync(cancellationToken); + var isMultiConcurrency = Utils.IsUsingMultiConcurrency(_environmentVariables); Func processingFunc = async () => { @@ -358,6 +373,17 @@ internal async Task InvokeOnceAsync(CancellationToken cancellationToken = defaul SetInvocationTraceId(impl.RuntimeApiHeaders.TraceId); } + // Initialize ResponseStreamFactory — includes RuntimeApiClient reference + var runtimeApiClient = Client as RuntimeApiClient; + if (runtimeApiClient != null) + { + ResponseStreamFactory.InitializeInvocation( + invocation.LambdaContext.AwsRequestId, + isMultiConcurrency, + runtimeApiClient, + cancellationToken); + } + try { InvocationResponse response = null; @@ -372,15 +398,41 @@ internal async Task InvokeOnceAsync(CancellationToken cancellationToken = defaul catch (Exception exception) { WriteUnhandledExceptionToLog(exception); - await Client.ReportInvocationErrorAsync(invocation.LambdaContext.AwsRequestId, exception, cancellationToken); + + var responseStream = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency); + if (responseStream != null) + { + responseStream.ReportError(exception); + } + else + { + await Client.ReportInvocationErrorAsync(invocation.LambdaContext.AwsRequestId, exception, cancellationToken); + } } finally { _logger.LogInformation("Finished invoking handler"); } - if (invokeSucceeded) + var streamIfCreated = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency); + if (streamIfCreated != null) + { + streamIfCreated.MarkCompleted(); + + // If streaming was started, await the HTTP send task to ensure it completes + var sendTask = ResponseStreamFactory.GetSendTask(isMultiConcurrency); + if (sendTask != null) + { + // Wait for the streaming response to finish sending before allowing the next invocation to be processed. This ensures that responses are sent in the order the invocations were received. + await sendTask; + sendTask.Result.Dispose(); + } + + streamIfCreated.Dispose(); + } + else if (invokeSucceeded) { + // No streaming — send buffered response _logger.LogInformation("Starting sending response"); try { @@ -415,6 +467,7 @@ internal async Task InvokeOnceAsync(CancellationToken cancellationToken = defaul } finally { + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency); invocation.Dispose(); } }; diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/RawStreamingHttpClient.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/RawStreamingHttpClient.cs new file mode 100644 index 000000000..0226e0660 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/RawStreamingHttpClient.cs @@ -0,0 +1,291 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +#if NET8_0_OR_GREATER +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Net.Sockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.RuntimeSupport.Helpers; + +namespace Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming +{ + /// + /// A raw HTTP/1.1 client for sending streaming responses to the Lambda Runtime API + /// with support for HTTP trailing headers (used for error reporting). + /// + /// .NET's HttpClient/SocketsHttpHandler does not support sending HTTP/1.1 trailing headers. + /// The Lambda Runtime API requires error information to be sent as trailing headers + /// (Lambda-Runtime-Function-Error-Type and Lambda-Runtime-Function-Error-Body) after + /// the chunked transfer encoding body. This class gives us full control over the + /// HTTP wire format to properly send those trailers. + /// + internal class RawStreamingHttpClient : IDisposable + { + private readonly string _host; + private readonly int _port; + private TcpClient _tcpClient; + internal Stream _networkStream; + private bool _disposed; + + private readonly InternalLogger _logger = InternalLogger.GetDefaultLogger(); + + public RawStreamingHttpClient(string hostAndPort) + { + var parts = hostAndPort.Split(':'); + _host = parts[0]; + _port = parts.Length > 1 ? int.Parse(parts[1], CultureInfo.InvariantCulture) : 80; + } + + /// + /// Sends a streaming response to the Lambda Runtime API. + /// Connects via TCP, sends HTTP headers, then streams the response body + /// using chunked transfer encoding. When the response stream completes, + /// writes the chunked encoding terminator with optional trailing headers + /// for error reporting. + /// + /// The Lambda request ID. + /// The response stream that provides data and error state. + /// The User-Agent header value. + /// Cancellation token. + public async Task SendStreamingResponseAsync( + string awsRequestId, + ResponseStream responseStream, + string userAgent, + CancellationToken cancellationToken = default) + { + _tcpClient = new TcpClient(); + _tcpClient.NoDelay = true; + await _tcpClient.ConnectAsync(_host, _port, cancellationToken); + _networkStream = _tcpClient.GetStream(); + + // Send HTTP request line and headers + var path = $"/2018-06-01/runtime/invocation/{awsRequestId}/response"; + var headers = new StringBuilder(); + headers.Append($"POST {path} HTTP/1.1\r\n"); + headers.Append($"Host: {_host}:{_port}\r\n"); + headers.Append($"User-Agent: {userAgent}\r\n"); + headers.Append($"Content-Type: application/vnd.awslambda.http-integration-response\r\n"); + headers.Append($"{StreamingConstants.ResponseModeHeader}: {StreamingConstants.StreamingResponseMode}\r\n"); + headers.Append("Transfer-Encoding: chunked\r\n"); + headers.Append($"Trailer: {StreamingConstants.ErrorTypeTrailer}, {StreamingConstants.ErrorBodyTrailer}\r\n"); + headers.Append("\r\n"); + + var headerBytes = Encoding.ASCII.GetBytes(headers.ToString()); + await _networkStream.WriteAsync(headerBytes, cancellationToken); + await _networkStream.FlushAsync(cancellationToken); + + // Hand the network stream (wrapped in a chunked writer) to the ResponseStream + var chunkedWriter = new ChunkedStreamWriter(_networkStream); + await responseStream.SetHttpOutputStreamAsync(chunkedWriter, cancellationToken); + + _logger.LogInformation("In SendStreamingResponseAsync waiting for the underlying Lambda response stream to indicate it is complete."); + + // Wait for the handler to finish writing + await responseStream.WaitForCompletionAsync(cancellationToken); + + // Write the chunked encoding terminator with optional trailers + if (responseStream.HasError) + { + _logger.LogInformation("Adding response stream trailing error headers"); + await WriteTerminatorWithTrailersAsync(responseStream.ReportedError, cancellationToken); + } + else + { + // No error — write simple terminator: 0\r\n\r\n + var terminator = Encoding.ASCII.GetBytes("0\r\n\r\n"); + await _networkStream.WriteAsync(terminator, cancellationToken); + } + + await _networkStream.FlushAsync(cancellationToken); + + // Read and discard the HTTP response (we don't need it, but must consume it) + await ReadAndDiscardResponseAsync(cancellationToken); + } + + /// + /// Writes the chunked encoding terminator with HTTP trailing headers for error reporting. + /// Format: + /// 0\r\n + /// Lambda-Runtime-Function-Error-Type: errorType\r\n + /// Lambda-Runtime-Function-Error-Body: base64EncodedErrorBodyJson\r\n + /// \r\n + /// + /// The error body JSON is Base64-encoded because LambdaJsonExceptionWriter produces + /// pretty-printed multi-line JSON. HTTP trailer values cannot contain raw CR/LF characters + /// as they would break the HTTP framing — the Runtime API would see the first newline + /// inside the JSON as the end of the trailer and treat the rest as malformed data, + /// resulting in Runtime.TruncatedResponse instead of the actual error. + /// + internal async Task WriteTerminatorWithTrailersAsync(Exception exception, CancellationToken cancellationToken) + { + var exceptionInfo = ExceptionInfo.GetExceptionInfo(exception); + var errorBodyJson = LambdaJsonExceptionWriter.WriteJson(exceptionInfo); + var errorBodyBase64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(errorBodyJson)); + + InternalLogger.GetDefaultLogger().LogInformation($"Writing trailing header {StreamingConstants.ErrorTypeTrailer} with error type {exceptionInfo.ErrorType}."); + var trailers = new StringBuilder(); + trailers.Append("0\r\n"); // zero-length chunk (end of body) + trailers.Append($"{StreamingConstants.ErrorTypeTrailer}: {exceptionInfo.ErrorType}\r\n"); + trailers.Append($"{StreamingConstants.ErrorBodyTrailer}: {errorBodyBase64}\r\n"); + trailers.Append("\r\n"); // end of trailers + + var trailerBytes = Encoding.UTF8.GetBytes(trailers.ToString()); + await _networkStream.WriteAsync(trailerBytes, cancellationToken); + } + + /// + /// Reads and discards the HTTP response from the Runtime API. + /// We need to consume the response to properly close the connection, + /// but we don't need to process it. + /// + internal async Task ReadAndDiscardResponseAsync(CancellationToken cancellationToken) + { + var buffer = new byte[4096]; + try + { + // Read until we get the full response. The Runtime API sends a short response. + var totalRead = 0; + var responseText = new StringBuilder(); + while (true) + { + var bytesRead = await _networkStream.ReadAsync(buffer, 0, buffer.Length, cancellationToken); + if (bytesRead == 0) + break; + + totalRead += bytesRead; + responseText.Append(Encoding.ASCII.GetString(buffer, 0, bytesRead)); + + // Check if we've received the complete response (ends with \r\n\r\n for headers, + // or we've read the content-length worth of body) + var text = responseText.ToString(); + if (text.Contains("\r\n\r\n")) + { + // Find Content-Length to know if there's a body to read + var headerEnd = text.IndexOf("\r\n\r\n", StringComparison.Ordinal); + var headers = text.Substring(0, headerEnd); + + var contentLengthMatch = System.Text.RegularExpressions.Regex.Match( + headers, @"Content-Length:\s*(\d+)", System.Text.RegularExpressions.RegexOptions.IgnoreCase); + + if (contentLengthMatch.Success) + { + var contentLength = int.Parse(contentLengthMatch.Groups[1].Value, CultureInfo.InvariantCulture); + var bodyStart = headerEnd + 4; // skip \r\n\r\n + var bodyRead = text.Length - bodyStart; + if (bodyRead >= contentLength) + break; + } + else + { + // No Content-Length, assume response is complete after headers + break; + } + } + + if (totalRead > 16384) + break; // Safety limit + } + } + catch (Exception ex) + { + // Log but don't throw — the streaming response was already sent + _logger.LogDebug($"Error reading Runtime API response: {ex.Message}"); + } + } + + public void Dispose() + { + if (!_disposed) + { + _networkStream?.Dispose(); + _tcpClient?.Dispose(); + _disposed = true; + } + } + } + + /// + /// A write-only Stream wrapper that writes data in HTTP/1.1 chunked transfer encoding format. + /// Each write produces a chunk: {size in hex}\r\n{data}\r\n + /// FlushAsync flushes the underlying network stream to ensure data is sent immediately. + /// The chunked encoding terminator (0\r\n...\r\n) is NOT written by this class — + /// it is handled by RawStreamingHttpClient to support trailing headers. + /// + internal class ChunkedStreamWriter : Stream + { + private readonly Stream _innerStream; + + public ChunkedStreamWriter(Stream innerStream) + { + _innerStream = innerStream ?? throw new ArgumentNullException(nameof(innerStream)); + } + + public override bool CanRead => false; + public override bool CanSeek => false; + public override bool CanWrite => true; + public override long Length => throw new NotSupportedException(); + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + WriteAsync(buffer, offset, count, CancellationToken.None).GetAwaiter().GetResult(); + } + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (count == 0) return; + + // Write chunk header: size in hex + \r\n + var chunkHeader = Encoding.ASCII.GetBytes($"{count:X}\r\n"); + await _innerStream.WriteAsync(chunkHeader, 0, chunkHeader.Length, cancellationToken); + + // Write chunk data + await _innerStream.WriteAsync(buffer, offset, count, cancellationToken); + + // Write chunk trailer: \r\n + var crlf = Encoding.ASCII.GetBytes("\r\n"); + await _innerStream.WriteAsync(crlf, 0, crlf.Length, cancellationToken); + } + + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (buffer.Length == 0) return; + + var chunkHeader = Encoding.ASCII.GetBytes($"{buffer.Length:X}\r\n"); + await _innerStream.WriteAsync(chunkHeader, cancellationToken); + await _innerStream.WriteAsync(buffer, cancellationToken); + await _innerStream.WriteAsync(Encoding.ASCII.GetBytes("\r\n"), cancellationToken); + } + + public override void Flush() => _innerStream.Flush(); + + public override Task FlushAsync(CancellationToken cancellationToken) => + _innerStream.FlushAsync(cancellationToken); + + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + } +} +#endif diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStream.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStream.cs new file mode 100644 index 000000000..8271bf4f1 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStream.cs @@ -0,0 +1,261 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System; +using System.Buffers; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.RuntimeSupport.Helpers; + +namespace Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming +{ + /// + /// Represents the writable stream used by Lambda handlers to write response data for streaming invocations. + /// + internal class ResponseStream + { + private long _bytesWritten; + private bool _isCompleted; + private bool _hasError; + private Exception _reportedError; + private readonly object _lock = new object(); + + // The live HTTP output stream, set by RawStreamingHttpClient when sending the streaming response. + private Stream _httpOutputStream; + private bool _disposedValue; + + // The wait time is a sanity timeout to avoid waiting indefinitely if SetHttpOutputStreamAsync is not called or takes too long to call. + // Reality is that SetHttpOutputStreamAsync should be called very quickly after CreateStream, so this timeout is generous to avoid false positives but still protects against hanging indefinitely. + private readonly static TimeSpan _httpStreamWaitTimeout = TimeSpan.FromSeconds(30); + + private readonly SemaphoreSlim _httpStreamReady = new SemaphoreSlim(0, 1); + private readonly SemaphoreSlim _completionSignal = new SemaphoreSlim(0, 1); + + private static readonly byte[] PreludeDelimiter = new byte[8]; + + /// + /// The number of bytes written to the Lambda response stream so far. + /// + public long BytesWritten => _bytesWritten; + + /// + /// Gets a value indicating whether an error has occurred. + /// + public bool HasError => _hasError; + + private readonly byte[] _prelude; + + + private readonly InternalLogger _logger; + + + internal Exception ReportedError => _reportedError; + + internal ResponseStream(byte[] prelude) + { + _logger = InternalLogger.GetDefaultLogger(); + _prelude = prelude; + } + + /// + /// Called by RawStreamingHttpClient to provide the HTTP output stream (a ChunkedStreamWriter). + /// + internal async Task SetHttpOutputStreamAsync(Stream httpOutputStream, CancellationToken cancellationToken = default) + { + _httpOutputStream = httpOutputStream; + + // Write the prelude BEFORE releasing _httpStreamReady. This prevents a race + // where a handler WriteAsync that is already waiting on the semaphore could + // sneak in and write body data before the prelude, causing intermittent + // "Failed to parse prelude JSON" errors from API Gateway. + // + // Note: we intentionally do NOT check ThrowIfCompletedOrError() here. + // SetHttpOutputStreamAsync is infrastructure setup called by RawStreamingHttpClient, + // not a handler write. For fast-completing responses (e.g. Results.Json), + // LambdaBootstrap may call MarkCompleted() before the TCP connection is established + // and this method is called. The prelude still needs to be written to the wire + // so the response is properly framed. + if (_prelude?.Length > 0) + { + _logger.LogDebug($"Writing prelude of {_prelude.Length} bytes to HTTP stream."); + + var combinedLength = _prelude.Length + PreludeDelimiter.Length; + var combined = ArrayPool.Shared.Rent(combinedLength); + try + { + Buffer.BlockCopy(_prelude, 0, combined, 0, _prelude.Length); + Buffer.BlockCopy(PreludeDelimiter, 0, combined, _prelude.Length, PreludeDelimiter.Length); + + await _httpOutputStream.WriteAsync(combined, 0, combinedLength, cancellationToken); + await _httpOutputStream.FlushAsync(cancellationToken); + } + finally + { + ArrayPool.Shared.Return(combined); + } + } + + _httpStreamReady.Release(); + } + + /// + /// Called by RawStreamingHttpClient to wait until the handler + /// finishes writing (MarkCompleted or ReportError). + /// + internal async Task WaitForCompletionAsync(CancellationToken cancellationToken = default) + { + await _completionSignal.WaitAsync(cancellationToken); + } + + internal async Task WriteAsync(byte[] buffer, CancellationToken cancellationToken = default) + { + if (buffer == null) + throw new ArgumentNullException(nameof(buffer)); + await WriteAsync(buffer, 0, buffer.Length, cancellationToken); + } + + /// + /// Asynchronously writes a portion of a byte array to the response stream. + /// + /// The byte array containing data to write. + /// The zero-based byte offset in buffer at which to begin copying bytes. + /// The number of bytes to write. + /// Optional cancellation token. + /// A task representing the asynchronous operation. + /// Thrown if the stream is already completed or an error has been reported. + public async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + { + if (buffer == null) + throw new ArgumentNullException(nameof(buffer)); + if (offset < 0 || offset > buffer.Length) + throw new ArgumentOutOfRangeException(nameof(offset)); + if (count < 0 || offset + count > buffer.Length) + throw new ArgumentOutOfRangeException(nameof(count)); + + // Wait for the HTTP stream to be ready (first write only blocks) + await _httpStreamReady.WaitAsync(_httpStreamWaitTimeout, cancellationToken); + try + { + _logger.LogDebug("Writing chunk to HTTP response stream."); + + lock (_lock) + { + // Only throw on error, not on completed. For buffered ASP.NET Core responses + // (e.g. Results.Json), the pipeline completes and LambdaBootstrap calls + // MarkCompleted() before the pre-start buffer has been flushed to the wire. + // The buffered data still needs to be written even after MarkCompleted. + if (_hasError) + throw new InvalidOperationException("Cannot write to a stream after an error has been reported."); + _bytesWritten += count; + } + + await _httpOutputStream.WriteAsync(buffer, offset, count, cancellationToken); + await _httpOutputStream.FlushAsync(cancellationToken); + } + finally + { + // Re-release so subsequent writes don't block + _httpStreamReady.Release(); + } + } + + /// + /// Reports an error that occurred during streaming. + /// This will send error information via HTTP trailing headers. + /// + /// The exception to report. + /// Thrown if the stream is already completed or an error has already been reported. + internal void ReportError(Exception exception) + { + if (exception == null) + throw new ArgumentNullException(nameof(exception)); + + lock (_lock) + { + if (_isCompleted) + throw new InvalidOperationException("Cannot report an error after the stream has been completed."); + if (_hasError) + throw new InvalidOperationException("An error has already been reported for this stream."); + + _hasError = true; + _reportedError = exception; + _isCompleted = true; + } + // Signal completion so RawStreamingHttpClient can write error trailers and finish + _completionSignal.Release(); + } + + internal void MarkCompleted() + { + bool shouldReleaseLock = false; + lock (_lock) + { + // Release lock if not already completed, otherwise do nothing (idempotent) + if (!_isCompleted) + { + shouldReleaseLock = true; + } + _isCompleted = true; + } + + if (shouldReleaseLock) + { + // Signal completion so RawStreamingHttpClient can write the final chunk and finish + _completionSignal.Release(); + } + } + + private void ThrowIfCompletedOrError() + { + if (_isCompleted) + throw new InvalidOperationException("Cannot write to a completed stream."); + if (_hasError) + throw new InvalidOperationException("Cannot write to a stream after an error has been reported."); + } + + /// + /// Disposes the stream. After calling Dispose, no further writes or error reports should be made. + /// + /// + protected virtual void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + try { _httpStreamReady.Release(); } catch (SemaphoreFullException) { /* Ignore if already released */ } + _httpStreamReady.Dispose(); + + try { _completionSignal.Release(); } catch (SemaphoreFullException) { /* Ignore if already released */ } + _completionSignal.Dispose(); + } + + _disposedValue = true; + } + } + + /// + /// Dispose of the stream. After calling Dispose, no further writes or error reports should be made. + /// + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + } +} diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamContext.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamContext.cs new file mode 100644 index 000000000..970c43138 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamContext.cs @@ -0,0 +1,59 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; + +namespace Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming +{ + /// + /// Internal context class used by ResponseStreamFactory to track per-invocation streaming state. + /// + internal class ResponseStreamContext + { + /// + /// The AWS request ID for the current invocation. + /// + public string AwsRequestId { get; set; } + + /// + /// Whether CreateStream() has been called for this invocation. + /// + public bool StreamCreated { get; set; } + + /// + /// The ResponseStream instance if created. + /// + public ResponseStream Stream { get; set; } + + /// + /// The RuntimeApiClient used to start the streaming HTTP POST. + /// + public RuntimeApiClient RuntimeApiClient { get; set; } + + /// + /// Cancellation token for the current invocation. + /// + public CancellationToken CancellationToken { get; set; } + + /// + /// The Task representing the in-flight HTTP POST to the Runtime API. + /// Started when CreateStream() is called, completes when the stream is finalized. + /// + public Task SendTask { get; set; } + } +} diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamFactory.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamFactory.cs new file mode 100644 index 000000000..27b34e8db --- /dev/null +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamFactory.cs @@ -0,0 +1,133 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming +{ + /// + /// Factory for creating streaming responses in AWS Lambda functions. + /// Call CreateStream() within your handler to opt into response streaming for that invocation. + /// + internal static class ResponseStreamFactory + { + // For on-demand mode (single invocation at a time) + private static ResponseStreamContext _onDemandContext; + + // For multi-concurrency mode (multiple concurrent invocations) + private static readonly AsyncLocal _asyncLocalContext = new AsyncLocal(); + + /// + /// Creates a streaming response for the current invocation. + /// Can only be called once per invocation. + /// + /// + /// + /// Thrown if called outside an invocation context. + /// Thrown if called more than once per invocation. + public static ResponseStream CreateStream(byte[] prelude) + { +#if NET8_0_OR_GREATER + var context = GetCurrentContext(); + + if (context == null) + { + throw new InvalidOperationException( + "ResponseStreamFactory.CreateStream() can only be called within a Lambda handler invocation."); + } + + if (context.StreamCreated) + { + throw new InvalidOperationException( + "ResponseStreamFactory.CreateStream() can only be called once per invocation."); + } + + var lambdaStream = new ResponseStream(prelude); + context.Stream = lambdaStream; + context.StreamCreated = true; + + // Start the HTTP POST to the Runtime API. + // This runs concurrently — SerializeToStreamAsync will block + // until the handler finishes writing or reports an error. + context.SendTask = context.RuntimeApiClient.StartStreamingResponseAsync( + context.AwsRequestId, lambdaStream, context.CancellationToken); + + return lambdaStream; +#else + throw new NotImplementedException(); +#endif + } + + // Internal methods for LambdaBootstrap to manage state + + internal static void InitializeInvocation( + string awsRequestId, bool isMultiConcurrency, + RuntimeApiClient runtimeApiClient, CancellationToken cancellationToken) + { + var context = new ResponseStreamContext + { + AwsRequestId = awsRequestId, + StreamCreated = false, + Stream = null, + RuntimeApiClient = runtimeApiClient, + CancellationToken = cancellationToken + }; + + if (isMultiConcurrency) + { + _asyncLocalContext.Value = context; + } + else + { + _onDemandContext = context; + } + } + + internal static ResponseStream GetStreamIfCreated(bool isMultiConcurrency) + { + var context = isMultiConcurrency ? _asyncLocalContext.Value : _onDemandContext; + return context?.Stream; + } + + /// + /// Returns the Task for the in-flight HTTP send, or null if streaming wasn't started. + /// LambdaBootstrap awaits this after the handler returns to ensure the HTTP request completes. + /// + internal static Task GetSendTask(bool isMultiConcurrency) + { + var context = isMultiConcurrency ? _asyncLocalContext.Value : _onDemandContext; + return context?.SendTask; + } + + internal static void CleanupInvocation(bool isMultiConcurrency) + { + if (isMultiConcurrency) + { + _asyncLocalContext.Value = null; + } + else + { + _onDemandContext = null; + } + } + + private static ResponseStreamContext GetCurrentContext() + { + // Check multi-concurrency first (AsyncLocal), then on-demand + return _asyncLocalContext.Value ?? _onDemandContext; + } + } +} diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamLambdaCoreInitializerIsolated.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamLambdaCoreInitializerIsolated.cs new file mode 100644 index 000000000..b86864480 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamLambdaCoreInitializerIsolated.cs @@ -0,0 +1,61 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +#if NET8_0_OR_GREATER + +using System; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.Core.ResponseStreaming; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +#pragma warning disable CA2252 +namespace Amazon.Lambda.RuntimeSupport +{ + /// + /// This class is used to connect the created by to Amazon.Lambda.Core with it's public interfaces. + /// The deployed Lambda function might be referencing an older version of Amazon.Lambda.Core that does not have the public interfaces for response streaming, + /// so this class is used to avoid a direct dependency on Amazon.Lambda.Core in the rest of the response streaming implementation. + /// + /// Any code referencing this class must wrap the code around a try/catch for to allow for the case where the Lambda function + /// is deployed with an older version of Amazon.Lambda.Core that does not have the response streaming interfaces. + /// + /// + internal class ResponseStreamLambdaCoreInitializerIsolated + { + /// + /// Initalize Amazon.Lambda.Core with a factory method for creating that wraps the internal implementation. + /// + internal static void InitializeCore() + { +#if !ANALYZER_UNIT_TESTS // This precompiler directive is used to avoid the unit tests from needing a dependency on Amazon.Lambda.Core. + Func factory = (byte[] prelude) => new ImplLambdaResponseStream(ResponseStreamFactory.CreateStream(prelude)); + LambdaResponseStreamFactory.SetLambdaResponseStream(factory); +#endif + } + + /// + /// Implements the interface by wrapping a . This is used to connect the internal response streaming implementation to the public interfaces in Amazon.Lambda.Core. + /// + internal class ImplLambdaResponseStream : ILambdaResponseStream + { + private readonly ResponseStream _innerStream; + + internal ImplLambdaResponseStream(ResponseStream innerStream) + { + _innerStream = innerStream; + } + + /// + public long BytesWritten => _innerStream.BytesWritten; + + /// + public bool HasError => _innerStream.HasError; + + /// + public void Dispose() => _innerStream.Dispose(); + + /// + public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) => _innerStream.WriteAsync(buffer, offset, count, cancellationToken); + } + } +} +#endif diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/StreamingConstants.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/StreamingConstants.cs new file mode 100644 index 000000000..43ac607b7 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/StreamingConstants.cs @@ -0,0 +1,43 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +namespace Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming +{ + /// + /// Constants used for Lambda response streaming. + /// + internal static class StreamingConstants + { + /// + /// Header name for Lambda response mode. + /// + public const string ResponseModeHeader = "Lambda-Runtime-Function-Response-Mode"; + + /// + /// Value for streaming response mode. + /// + public const string StreamingResponseMode = "streaming"; + + /// + /// Trailer header name for error type. + /// + public const string ErrorTypeTrailer = "Lambda-Runtime-Function-Error-Type"; + + /// + /// Trailer header name for error body. + /// + public const string ErrorBodyTrailer = "Lambda-Runtime-Function-Error-Body"; + } +} diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/RuntimeApiClient.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/RuntimeApiClient.cs index daa9fff24..0cddfcd2a 100644 --- a/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/RuntimeApiClient.cs +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/RuntimeApiClient.cs @@ -20,6 +20,7 @@ using System.Threading; using System.Threading.Tasks; using Amazon.Lambda.RuntimeSupport.Bootstrap; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; namespace Amazon.Lambda.RuntimeSupport { @@ -177,6 +178,34 @@ public Task ReportRestoreErrorAsync(Exception exception, String errorType = null #endif +#if NET8_0_OR_GREATER + /// + /// Start sending a streaming response to the Lambda Runtime API. + /// Uses a raw TCP connection with chunked transfer encoding to support HTTP/1.1 + /// trailing headers for error reporting, which .NET's HttpClient does not support. + /// The actual data is written by the handler via ResponseStream.WriteAsync, which flows + /// through a ChunkedStreamWriter to the TCP connection. + /// This Task completes when the stream is finalized (MarkCompleted or error). + /// + /// The ID of the function request being responded to. + /// The ResponseStream that will provide the streaming data. + /// The optional cancellation token to use. + /// A Task representing the in-flight HTTP POST. The returned IDisposable is the RawStreamingHttpClient that owns the TCP connection. + internal virtual async Task StartStreamingResponseAsync( + string awsRequestId, ResponseStream responseStream, CancellationToken cancellationToken = default) + { + if (awsRequestId == null) throw new ArgumentNullException(nameof(awsRequestId)); + if (responseStream == null) throw new ArgumentNullException(nameof(responseStream)); + + var userAgent = _httpClient.DefaultRequestHeaders.UserAgent.ToString(); + var rawClient = new RawStreamingHttpClient(LambdaEnvironment.RuntimeServerHostAndPort); + + await rawClient.SendStreamingResponseAsync(awsRequestId, responseStream, userAgent, cancellationToken); + + return rawClient; + } +#endif + /// /// Send a response to a function invocation to the Runtime API as an asynchronous operation. /// diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Helpers/ConsoleLoggerWriter.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Helpers/ConsoleLoggerWriter.cs index 2caa708e3..a2417cbcc 100644 --- a/Libraries/src/Amazon.Lambda.RuntimeSupport/Helpers/ConsoleLoggerWriter.cs +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Helpers/ConsoleLoggerWriter.cs @@ -227,6 +227,7 @@ public LogLevelLoggerWriter(IEnvironmentVariables environmentVariables) /// public LogLevelLoggerWriter(TextWriter stdOutWriter, TextWriter stdErrorWriter) { + _environmentVariables = new SystemEnvironmentVariables(); Initialize(stdOutWriter, stdErrorWriter); } @@ -325,7 +326,7 @@ public IRuntimeApiHeaders CurrentRuntimeApiHeaders { get { - if (Utils.IsUsingMultiConcurrency(_environmentVariables)) + if (_currentRuntimeApiHeadersStorage != null && Utils.IsUsingMultiConcurrency(_environmentVariables)) { return _currentRuntimeApiHeadersStorage.Value; } @@ -333,7 +334,7 @@ public IRuntimeApiHeaders CurrentRuntimeApiHeaders } set { - if (Utils.IsUsingMultiConcurrency(_environmentVariables)) + if (_currentRuntimeApiHeadersStorage != null && Utils.IsUsingMultiConcurrency(_environmentVariables)) { _currentRuntimeApiHeadersStorage.Value = value; } diff --git a/Libraries/test/Amazon.Lambda.Annotations.SourceGenerators.Tests/ALBApiAttributeTests.cs b/Libraries/test/Amazon.Lambda.Annotations.SourceGenerators.Tests/ALBApiAttributeTests.cs new file mode 100644 index 000000000..10cb530b1 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.Annotations.SourceGenerators.Tests/ALBApiAttributeTests.cs @@ -0,0 +1,495 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Annotations.ALB; +using System.Linq; +using Xunit; + +namespace Amazon.Lambda.Annotations.SourceGenerators.Tests +{ + public class ALBApiAttributeTests + { + [Fact] + public void Constructor_SetsRequiredProperties() + { + // Arrange & Act + var attr = new ALBApiAttribute( + "arn:aws:elasticloadbalancing:us-east-1:123456789012:listener/app/my-alb/50dc6c495c0c9188/f2f7dc8efc522ab2", + "/api/orders/*", + 10); + + // Assert + Assert.Equal("arn:aws:elasticloadbalancing:us-east-1:123456789012:listener/app/my-alb/50dc6c495c0c9188/f2f7dc8efc522ab2", attr.ListenerArn); + Assert.Equal("/api/orders/*", attr.PathPattern); + Assert.Equal(10, attr.Priority); + } + + [Fact] + public void DefaultValues_AreCorrect() + { + // Arrange & Act + var attr = new ALBApiAttribute("arn:aws:elasticloadbalancing:us-east-1:123456789012:listener/app/my-alb/abc/def", "/hello", 1); + + // Assert + Assert.False(attr.MultiValueHeaders); + Assert.False(attr.IsMultiValueHeadersSet); + Assert.Null(attr.HostHeader); + Assert.Null(attr.HttpMethod); + Assert.Null(attr.ResourceName); + Assert.False(attr.IsResourceNameSet); + } + + [Fact] + public void MultiValueHeaders_WhenExplicitlySet_IsTracked() + { + var attr = new ALBApiAttribute("arn:aws:elasticloadbalancing:us-east-1:123456789012:listener/app/my-alb/abc/def", "/hello", 1); + + // Before setting + Assert.False(attr.IsMultiValueHeadersSet); + + // After setting to false explicitly + attr.MultiValueHeaders = false; + Assert.True(attr.IsMultiValueHeadersSet); + Assert.False(attr.MultiValueHeaders); + + // After setting to true + attr.MultiValueHeaders = true; + Assert.True(attr.IsMultiValueHeadersSet); + Assert.True(attr.MultiValueHeaders); + } + + [Fact] + public void ResourceName_WhenExplicitlySet_IsTracked() + { + var attr = new ALBApiAttribute("arn:aws:elasticloadbalancing:us-east-1:123456789012:listener/app/my-alb/abc/def", "/hello", 1); + + Assert.False(attr.IsResourceNameSet); + + attr.ResourceName = "MyCustomName"; + Assert.True(attr.IsResourceNameSet); + Assert.Equal("MyCustomName", attr.ResourceName); + } + + [Fact] + public void TemplateReference_IsAccepted() + { + var attr = new ALBApiAttribute("@MyALBListener", "/api/*", 5); + + Assert.Equal("@MyALBListener", attr.ListenerArn); + Assert.StartsWith("@", attr.ListenerArn); + } + + [Fact] + public void OptionalProperties_CanBeSet() + { + var attr = new ALBApiAttribute("@MyALBListener", "/api/*", 5) + { + HostHeader = "api.example.com", + HttpMethod = "GET", + MultiValueHeaders = true, + ResourceName = "MyALBTarget" + }; + + Assert.Equal("api.example.com", attr.HostHeader); + Assert.Equal("GET", attr.HttpMethod); + Assert.True(attr.MultiValueHeaders); + Assert.Equal("MyALBTarget", attr.ResourceName); + } + + // ===== Validation Tests ===== + + [Fact] + public void Validate_ValidArn_ReturnsNoErrors() + { + var attr = new ALBApiAttribute( + "arn:aws:elasticloadbalancing:us-east-1:123456789012:listener/app/my-alb/abc/def", + "/api/*", + 1); + + var errors = attr.Validate(); + Assert.Empty(errors); + } + + [Fact] + public void Validate_ValidTemplateReference_ReturnsNoErrors() + { + var attr = new ALBApiAttribute("@MyALBListener", "/api/*", 1); + + var errors = attr.Validate(); + Assert.Empty(errors); + } + + [Fact] + public void Validate_EmptyListenerArn_ReturnsError() + { + var attr = new ALBApiAttribute("", "/api/*", 1); + + var errors = attr.Validate(); + Assert.Single(errors); + Assert.Contains("ListenerArn", errors[0]); + Assert.Contains("required", errors[0]); + } + + [Fact] + public void Validate_NullListenerArn_ReturnsError() + { + var attr = new ALBApiAttribute(null, "/api/*", 1); + + var errors = attr.Validate(); + Assert.Single(errors); + Assert.Contains("ListenerArn", errors[0]); + } + + [Fact] + public void Validate_InvalidListenerArn_NotArnOrReference_ReturnsError() + { + var attr = new ALBApiAttribute("some-random-string", "/api/*", 1); + + var errors = attr.Validate(); + Assert.Single(errors); + Assert.Contains("ListenerArn", errors[0]); + Assert.Contains("arn:", errors[0]); + } + + [Fact] + public void Validate_EmptyPathPattern_ReturnsError() + { + var attr = new ALBApiAttribute("@MyListener", "", 1); + + var errors = attr.Validate(); + Assert.Single(errors); + Assert.Contains("PathPattern", errors[0]); + Assert.Contains("required", errors[0]); + } + + [Fact] + public void Validate_NullPathPattern_ReturnsError() + { + var attr = new ALBApiAttribute("@MyListener", null, 1); + + var errors = attr.Validate(); + Assert.Single(errors); + Assert.Contains("PathPattern", errors[0]); + } + + [Theory] + [InlineData(0)] + [InlineData(-1)] + [InlineData(50001)] + [InlineData(100000)] + public void Validate_InvalidPriority_ReturnsError(int priority) + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", priority); + + var errors = attr.Validate(); + Assert.Single(errors); + Assert.Contains("Priority", errors[0]); + Assert.Contains("1 and 50000", errors[0]); + } + + [Theory] + [InlineData(1)] + [InlineData(50000)] + [InlineData(100)] + [InlineData(25000)] + public void Validate_ValidPriority_ReturnsNoErrors(int priority) + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", priority); + + var errors = attr.Validate(); + Assert.Empty(errors); + } + + [Fact] + public void Validate_InvalidResourceName_ReturnsError() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + ResourceName = "invalid-name!" + }; + + var errors = attr.Validate(); + Assert.Single(errors); + Assert.Contains("ResourceName", errors[0]); + Assert.Contains("alphanumeric", errors[0]); + } + + [Fact] + public void Validate_ValidResourceName_ReturnsNoErrors() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + ResourceName = "MyValidResource123" + }; + + var errors = attr.Validate(); + Assert.Empty(errors); + } + + [Fact] + public void Validate_UnsetResourceName_ReturnsNoErrors() + { + // ResourceName not set should not produce validation errors + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1); + + var errors = attr.Validate(); + Assert.Empty(errors); + Assert.False(attr.IsResourceNameSet); + } + + [Theory] + [InlineData("GET")] + [InlineData("POST")] + [InlineData("PUT")] + [InlineData("PATCH")] + [InlineData("DELETE")] + [InlineData("HEAD")] + [InlineData("OPTIONS")] + [InlineData("get")] + [InlineData("post")] + public void Validate_ValidHttpMethod_ReturnsNoErrors(string method) + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + HttpMethod = method + }; + + var errors = attr.Validate(); + Assert.Empty(errors); + } + + [Fact] + public void Validate_InvalidHttpMethod_ReturnsError() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + HttpMethod = "INVALID" + }; + + var errors = attr.Validate(); + Assert.Single(errors); + Assert.Contains("HttpMethod", errors[0]); + } + + [Fact] + public void Validate_NullHttpMethod_ReturnsNoErrors() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + HttpMethod = null + }; + + var errors = attr.Validate(); + Assert.Empty(errors); + } + + [Fact] + public void Validate_MultipleErrors_ReturnsAll() + { + var attr = new ALBApiAttribute("", "", 0) + { + ResourceName = "invalid-name!", + HttpMethod = "INVALID" + }; + + var errors = attr.Validate(); + // Should have errors for: ListenerArn, PathPattern, Priority, ResourceName, HttpMethod + Assert.Equal(5, errors.Count); + Assert.Contains(errors, e => e.Contains("ListenerArn")); + Assert.Contains(errors, e => e.Contains("PathPattern")); + Assert.Contains(errors, e => e.Contains("Priority")); + Assert.Contains(errors, e => e.Contains("ResourceName")); + Assert.Contains(errors, e => e.Contains("HttpMethod")); + } + + [Fact] + public void Validate_AllValidWithOptionals_ReturnsNoErrors() + { + var attr = new ALBApiAttribute( + "arn:aws:elasticloadbalancing:us-east-1:123456789012:listener/app/my-alb/abc/def", + "/api/v1/products/*", + 42) + { + MultiValueHeaders = true, + HostHeader = "api.example.com", + HttpMethod = "POST", + ResourceName = "ProductsALB" + }; + + var errors = attr.Validate(); + Assert.Empty(errors); + } + + // ===== HTTP Header Condition Tests ===== + + [Fact] + public void HttpHeaderCondition_DefaultValues_AreNull() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1); + + Assert.Null(attr.HttpHeaderConditionName); + Assert.Null(attr.HttpHeaderConditionValues); + } + + [Fact] + public void HttpHeaderCondition_BothSet_ReturnsNoErrors() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + HttpHeaderConditionName = "X-Environment", + HttpHeaderConditionValues = new[] { "dev", "staging" } + }; + + var errors = attr.Validate(); + Assert.Empty(errors); + Assert.Equal("X-Environment", attr.HttpHeaderConditionName); + Assert.Equal(2, attr.HttpHeaderConditionValues.Length); + Assert.Equal("dev", attr.HttpHeaderConditionValues[0]); + Assert.Equal("staging", attr.HttpHeaderConditionValues[1]); + } + + [Fact] + public void HttpHeaderCondition_NameSetWithoutValues_ReturnsError() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + HttpHeaderConditionName = "X-Environment" + }; + + var errors = attr.Validate(); + Assert.Single(errors); + Assert.Contains("HttpHeaderConditionName", errors[0]); + Assert.Contains("HttpHeaderConditionValues", errors[0]); + } + + [Fact] + public void HttpHeaderCondition_ValuesSetWithoutName_ReturnsError() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + HttpHeaderConditionValues = new[] { "dev" } + }; + + var errors = attr.Validate(); + Assert.Single(errors); + Assert.Contains("HttpHeaderConditionValues", errors[0]); + Assert.Contains("HttpHeaderConditionName", errors[0]); + } + + [Fact] + public void HttpHeaderCondition_NameSetWithEmptyValues_ReturnsError() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + HttpHeaderConditionName = "User-Agent", + HttpHeaderConditionValues = new string[0] + }; + + var errors = attr.Validate(); + Assert.Single(errors); + Assert.Contains("HttpHeaderConditionName", errors[0]); + } + + [Fact] + public void HttpHeaderCondition_WithWildcards_ReturnsNoErrors() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + HttpHeaderConditionName = "User-Agent", + HttpHeaderConditionValues = new[] { "*Chrome*", "*Safari*" } + }; + + var errors = attr.Validate(); + Assert.Empty(errors); + } + + // ===== Query String Condition Tests ===== + + [Fact] + public void QueryStringConditions_DefaultValue_IsNull() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1); + Assert.Null(attr.QueryStringConditions); + } + + [Fact] + public void QueryStringConditions_WithKeyValuePairs_ReturnsNoErrors() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + QueryStringConditions = new[] { "version=v1", "=*example*" } + }; + + var errors = attr.Validate(); + Assert.Empty(errors); + Assert.Equal(2, attr.QueryStringConditions.Length); + Assert.Equal("version=v1", attr.QueryStringConditions[0]); + Assert.Equal("=*example*", attr.QueryStringConditions[1]); + } + + [Fact] + public void QueryStringConditions_WithSingleEntry_ReturnsNoErrors() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + QueryStringConditions = new[] { "env=prod" } + }; + + var errors = attr.Validate(); + Assert.Empty(errors); + } + + // ===== Source IP Condition Tests ===== + + [Fact] + public void SourceIpConditions_DefaultValue_IsNull() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1); + Assert.Null(attr.SourceIpConditions); + } + + [Fact] + public void SourceIpConditions_WithCidrBlocks_ReturnsNoErrors() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + SourceIpConditions = new[] { "192.0.2.0/24", "198.51.100.10/32" } + }; + + var errors = attr.Validate(); + Assert.Empty(errors); + Assert.Equal(2, attr.SourceIpConditions.Length); + } + + [Fact] + public void SourceIpConditions_WithIPv6_ReturnsNoErrors() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + SourceIpConditions = new[] { "2001:db8::/32" } + }; + + var errors = attr.Validate(); + Assert.Empty(errors); + } + + // ===== Combined Condition Tests ===== + + [Fact] + public void AllConditions_CanBeSetTogether_ReturnsNoErrors() + { + var attr = new ALBApiAttribute("@MyListener", "/api/*", 1) + { + HostHeader = "api.example.com", + HttpMethod = "POST", + HttpHeaderConditionName = "X-Environment", + HttpHeaderConditionValues = new[] { "dev" }, + QueryStringConditions = new[] { "version=v1" }, + SourceIpConditions = new[] { "10.0.0.0/8" } + }; + + var errors = attr.Validate(); + Assert.Empty(errors); + } + } +} diff --git a/Libraries/test/Amazon.Lambda.Annotations.SourceGenerators.Tests/ALBApiModelTests.cs b/Libraries/test/Amazon.Lambda.Annotations.SourceGenerators.Tests/ALBApiModelTests.cs new file mode 100644 index 000000000..f45d4a67a --- /dev/null +++ b/Libraries/test/Amazon.Lambda.Annotations.SourceGenerators.Tests/ALBApiModelTests.cs @@ -0,0 +1,272 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Annotations.SourceGenerator; +using Amazon.Lambda.Annotations.SourceGenerator.Diagnostics; +using Amazon.Lambda.Annotations.SourceGenerator.Extensions; +using Amazon.Lambda.Annotations.SourceGenerator.Models; +using Amazon.Lambda.Annotations.SourceGenerator.Models.Attributes; +using System.Collections.Generic; +using System.Linq; +using Xunit; + +namespace Amazon.Lambda.Annotations.SourceGenerators.Tests +{ + public class ALBApiModelTests + { + [Fact] + public void TypeFullNames_ContainsALBConstants() + { + Assert.Equal("Amazon.Lambda.ApplicationLoadBalancerEvents.ApplicationLoadBalancerRequest", TypeFullNames.ApplicationLoadBalancerRequest); + Assert.Equal("Amazon.Lambda.ApplicationLoadBalancerEvents.ApplicationLoadBalancerResponse", TypeFullNames.ApplicationLoadBalancerResponse); + Assert.Equal("Amazon.Lambda.Annotations.ALB.ALBApiAttribute", TypeFullNames.ALBApiAttribute); + } + + [Fact] + public void TypeFullNames_Events_ContainsALBApiAttribute() + { + Assert.Contains(TypeFullNames.ALBApiAttribute, TypeFullNames.Events); + } + + [Fact] + public void TypeFullNames_ALBRequests_ContainsLoadBalancerRequest() + { + Assert.Contains(TypeFullNames.ApplicationLoadBalancerRequest, TypeFullNames.ALBRequests); + Assert.Single(TypeFullNames.ALBRequests); + } + + [Fact] + public void EventType_HasALBValue() + { + // Verify the ALB enum value exists + var albEvent = EventType.ALB; + Assert.Equal(EventType.ALB, albEvent); + + // Verify it's distinct from other event types + Assert.NotEqual(EventType.API, albEvent); + Assert.NotEqual(EventType.SQS, albEvent); + } + + [Fact] + public void ALBApiAttributeBuilder_BuildsFromConstructorArgs() + { + // This tests the attribute builder by constructing an ALBApiAttribute directly + // (since we can't easily mock Roslyn AttributeData in unit tests, we test the attribute itself) + var attr = new Annotations.ALB.ALBApiAttribute("@MyListener", "/api/*", 5); + + Assert.Equal("@MyListener", attr.ListenerArn); + Assert.Equal("/api/*", attr.PathPattern); + Assert.Equal(5, attr.Priority); + } + + [Fact] + public void ALBApiAttributeBuilder_BuildsWithAllOptionalProperties() + { + var attr = new Annotations.ALB.ALBApiAttribute("arn:aws:elasticloadbalancing:us-east-1:123456789012:listener/app/my-alb/abc/def", "/api/v1/*", 10) + { + MultiValueHeaders = true, + HostHeader = "api.example.com", + HttpMethod = "POST", + ResourceName = "MyCustomALB" + }; + + Assert.Equal("arn:aws:elasticloadbalancing:us-east-1:123456789012:listener/app/my-alb/abc/def", attr.ListenerArn); + Assert.Equal("/api/v1/*", attr.PathPattern); + Assert.Equal(10, attr.Priority); + Assert.True(attr.MultiValueHeaders); + Assert.True(attr.IsMultiValueHeadersSet); + Assert.Equal("api.example.com", attr.HostHeader); + Assert.Equal("POST", attr.HttpMethod); + Assert.Equal("MyCustomALB", attr.ResourceName); + Assert.True(attr.IsResourceNameSet); + } + + [Fact] + public void LambdaMethodModel_ReturnsApplicationLoadBalancerResponse_WhenDirectReturn() + { + var model = new LambdaMethodModel + { + ReturnsVoid = false, + ReturnsGenericTask = false, + ReturnType = new TypeModel + { + FullName = TypeFullNames.ApplicationLoadBalancerResponse, + TypeArguments = new List() + } + }; + + Assert.True(model.ReturnsApplicationLoadBalancerResponse); + } + + [Fact] + public void LambdaMethodModel_ReturnsApplicationLoadBalancerResponse_WhenTaskReturn() + { + var model = new LambdaMethodModel + { + ReturnsVoid = false, + ReturnsGenericTask = true, + ReturnType = new TypeModel + { + FullName = "System.Threading.Tasks.Task`1", + TypeArguments = new List + { + new TypeModel { FullName = TypeFullNames.ApplicationLoadBalancerResponse } + } + } + }; + + Assert.True(model.ReturnsApplicationLoadBalancerResponse); + } + + [Fact] + public void LambdaMethodModel_ReturnsApplicationLoadBalancerResponse_FalseWhenVoid() + { + var model = new LambdaMethodModel + { + ReturnsVoid = true, + ReturnsGenericTask = false, + ReturnType = new TypeModel + { + FullName = "void", + TypeArguments = new List() + } + }; + + Assert.False(model.ReturnsApplicationLoadBalancerResponse); + } + + [Fact] + public void LambdaMethodModel_ReturnsApplicationLoadBalancerResponse_FalseWhenDifferentType() + { + var model = new LambdaMethodModel + { + ReturnsVoid = false, + ReturnsGenericTask = false, + ReturnType = new TypeModel + { + FullName = "System.String", + TypeArguments = new List() + } + }; + + Assert.False(model.ReturnsApplicationLoadBalancerResponse); + } + + [Fact] + public void ParameterListExtension_ALBRequest_IsNotConvertible() + { + // ApplicationLoadBalancerRequest parameters should be treated as pass-through + var parameters = new List + { + new ParameterModel + { + Name = "request", + Type = new TypeModel { FullName = TypeFullNames.ApplicationLoadBalancerRequest }, + Attributes = new List() + } + }; + + Assert.False(parameters.HasConvertibleParameter()); + } + + [Fact] + public void ParameterListExtension_FromQuery_IsConvertible() + { + // A [FromQuery] string parameter should be convertible + var parameters = new List + { + new ParameterModel + { + Name = "name", + Type = new TypeModel { FullName = "System.String" }, + Attributes = new List + { + new AttributeModel + { + Data = new Annotations.APIGateway.FromQueryAttribute(), + Type = new TypeModel { FullName = TypeFullNames.FromQueryAttribute } + } + } + } + }; + + Assert.True(parameters.HasConvertibleParameter()); + } + + [Fact] + public void ParameterListExtension_ILambdaContext_IsNotConvertible() + { + var parameters = new List + { + new ParameterModel + { + Name = "context", + Type = new TypeModel { FullName = TypeFullNames.ILambdaContext }, + Attributes = new List() + } + }; + + Assert.False(parameters.HasConvertibleParameter()); + } + + [Fact] + public void ParameterListExtension_FromBodyString_IsNotConvertible() + { + // A [FromBody] string parameter should NOT be convertible (string body is pass-through) + var parameters = new List + { + new ParameterModel + { + Name = "body", + Type = new TypeModel { FullName = "string" }, + Attributes = new List + { + new AttributeModel + { + Data = new Annotations.APIGateway.FromBodyAttribute(), + Type = new TypeModel { FullName = TypeFullNames.FromBodyAttribute } + } + } + } + }; + + Assert.False(parameters.HasConvertibleParameter()); + } + + [Fact] + public void DiagnosticDescriptors_FromRouteNotSupportedOnAlb_Exists() + { + Assert.Equal("AWSLambda0134", DiagnosticDescriptors.FromRouteNotSupportedOnAlb.Id); + Assert.Equal(Microsoft.CodeAnalysis.DiagnosticSeverity.Error, DiagnosticDescriptors.FromRouteNotSupportedOnAlb.DefaultSeverity); + } + + [Fact] + public void DiagnosticDescriptors_AlbUnmappedParameter_Exists() + { + Assert.Equal("AWSLambda0135", DiagnosticDescriptors.AlbUnmappedParameter.Id); + Assert.Equal(Microsoft.CodeAnalysis.DiagnosticSeverity.Error, DiagnosticDescriptors.AlbUnmappedParameter.DefaultSeverity); + } + + [Fact] + public void ALBFromQuery_ParameterName_DefaultsToParameterName() + { + // When Name is not set, ALB FromQueryAttribute should default to parameter name + var attr = new Annotations.ALB.FromQueryAttribute(); + Assert.Null(attr.Name); + } + + [Fact] + public void ALBFromQuery_ParameterName_UsesExplicitName() + { + var attr = new Annotations.ALB.FromQueryAttribute { Name = "custom_name" }; + Assert.Equal("custom_name", attr.Name); + } + + [Fact] + public void ALBFromHeader_ParameterName_UsesExplicitName() + { + var attr = new Annotations.ALB.FromHeaderAttribute { Name = "X-Custom-Header" }; + Assert.Equal("X-Custom-Header", attr.Name); + } + } +} diff --git a/Libraries/test/Amazon.Lambda.Annotations.SourceGenerators.Tests/Amazon.Lambda.Annotations.SourceGenerators.Tests.csproj b/Libraries/test/Amazon.Lambda.Annotations.SourceGenerators.Tests/Amazon.Lambda.Annotations.SourceGenerators.Tests.csproj index c8cc6f306..56da7d597 100644 --- a/Libraries/test/Amazon.Lambda.Annotations.SourceGenerators.Tests/Amazon.Lambda.Annotations.SourceGenerators.Tests.csproj +++ b/Libraries/test/Amazon.Lambda.Annotations.SourceGenerators.Tests/Amazon.Lambda.Annotations.SourceGenerators.Tests.csproj @@ -208,6 +208,8 @@ + + - + diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/BaseCustomRuntimeTest.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/BaseCustomRuntimeTest.cs index c220a671e..314aa45c4 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/BaseCustomRuntimeTest.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/BaseCustomRuntimeTest.cs @@ -17,7 +17,7 @@ public class BaseCustomRuntimeTest { public const int FUNCTION_MEMORY_MB = 512; - protected static readonly RegionEndpoint TestRegion = RegionEndpoint.USWest2; + public static readonly RegionEndpoint TestRegion = RegionEndpoint.USWest2; protected static readonly string LAMBDA_ASSUME_ROLE_POLICY = @" { @@ -63,7 +63,7 @@ protected BaseCustomRuntimeTest(IntegrationTestFixture fixture, string functionN /// /// /// - protected async Task CleanUpTestResources(AmazonS3Client s3Client, AmazonLambdaClient lambdaClient, + public async Task CleanUpTestResources(AmazonS3Client s3Client, AmazonLambdaClient lambdaClient, AmazonIdentityManagementServiceClient iamClient, bool roleAlreadyExisted) { await DeleteFunctionIfExistsAsync(lambdaClient); @@ -109,7 +109,7 @@ await iamClient.DetachRolePolicyAsync(new DetachRolePolicyRequest } } - protected async Task PrepareTestResources(IAmazonS3 s3Client, IAmazonLambda lambdaClient, + public async Task PrepareTestResources(IAmazonS3 s3Client, IAmazonLambda lambdaClient, AmazonIdentityManagementServiceClient iamClient) { var roleAlreadyExisted = await ValidateAndSetIamRoleArn(iamClient); @@ -288,7 +288,7 @@ protected async Task CreateFunctionAsync(IAmazonLambda lambdaClient, string buck Handler = Handler, MemorySize = FUNCTION_MEMORY_MB, Timeout = 30, - Runtime = Runtime.Dotnet6, + Runtime = Runtime.Dotnet10, Role = ExecutionRoleArn }; @@ -351,7 +351,16 @@ private string GetDeploymentZipPath() if (!File.Exists(deploymentZipFile)) { - throw new NoDeploymentPackageFoundException(); + var message = new StringBuilder(); + message.AppendLine($"Deployment package for {DeploymentPackageZipRelativePath} not found at expected path: {deploymentZipFile}"); + message.AppendLine("Available Test Bundles:"); + foreach (var kvp in _fixture.TestAppPaths) + { + message.AppendLine($"{kvp.Key}: {kvp.Value}"); + } + + + throw new NoDeploymentPackageFoundException(message.ToString()); } return deploymentZipFile; @@ -380,7 +389,9 @@ private static string FindUp(string path, string fileOrDirectoryName, bool combi protected class NoDeploymentPackageFoundException : Exception { + public NoDeploymentPackageFoundException() { } + public NoDeploymentPackageFoundException(string message) : base(message) { } } private ApplicationLogLevel ConvertRuntimeLogLevel(RuntimeLogLevel runtimeLogLevel) diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/CustomRuntimeTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/CustomRuntimeTests.cs index b548d5ba0..8ab008d66 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/CustomRuntimeTests.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/CustomRuntimeTests.cs @@ -48,7 +48,7 @@ public async Task TestAllNET8HandlersAsync() public class CustomRuntimeTests : BaseCustomRuntimeTest { - public enum TargetFramework { NET6, NET8} + public enum TargetFramework { NET8 } private TargetFramework _targetFramework; diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/CommandLineWrapper.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/CommandLineWrapper.cs index aa8651eae..ea6fd059e 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/CommandLineWrapper.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/CommandLineWrapper.cs @@ -1,5 +1,6 @@ using System; using System.Diagnostics; +using System.Text; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -31,6 +32,7 @@ public static async Task Run(string command, string arguments, string workingDir tcs.TrySetResult(true); }; + var output = new StringBuilder(); try { // Attach event handlers @@ -39,6 +41,7 @@ public static async Task Run(string command, string arguments, string workingDir if (!string.IsNullOrEmpty(args.Data)) { Console.WriteLine(args.Data); + output.Append(args.Data); } }; @@ -47,6 +50,7 @@ public static async Task Run(string command, string arguments, string workingDir if (!string.IsNullOrEmpty(args.Data)) { Console.WriteLine(args.Data); + output.Append(args.Data); } }; @@ -78,6 +82,7 @@ public static async Task Run(string command, string arguments, string workingDir catch (Exception ex) { Console.WriteLine("Exception: " + ex); + Console.WriteLine(output.ToString()); if (!process.HasExited) { process.Kill(); @@ -87,4 +92,4 @@ public static async Task Run(string command, string arguments, string workingDir Assert.True(process.ExitCode == 0, $"Command '{command} {arguments}' failed."); } } -} \ No newline at end of file +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/LambdaToolsHelper.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/LambdaToolsHelper.cs index 42a02aac6..154c84f75 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/LambdaToolsHelper.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/LambdaToolsHelper.cs @@ -10,6 +10,9 @@ public static class LambdaToolsHelper public static string GetTempTestAppDirectory(string workingDirectory, string testAppPath) { +#if DEBUG + return Path.GetFullPath(Path.Combine(workingDirectory, testAppPath)); +#else var customTestAppPath = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); Directory.CreateDirectory(customTestAppPath); @@ -17,6 +20,7 @@ public static string GetTempTestAppDirectory(string workingDirectory, string tes CopyDirectory(currentDir, customTestAppPath); return Path.Combine(customTestAppPath, testAppPath); +#endif } public static async Task InstallLambdaTools() @@ -78,4 +82,4 @@ private static void CopyDirectory(DirectoryInfo dir, string destDirName) CopyDirectory(subDir, tempPath); } } -} \ No newline at end of file +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestCollection.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestCollection.cs index c9ce90e35..6e066eb28 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestCollection.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestCollection.cs @@ -2,8 +2,8 @@ namespace Amazon.Lambda.RuntimeSupport.IntegrationTests; -[CollectionDefinition("Integration Tests")] -public class IntegrationTestCollection : ICollectionFixture +[CollectionDefinition("Integration Tests", DisableParallelization = true)] +public class IntegrationTestCollection : ICollectionFixture, ICollectionFixture { -} \ No newline at end of file +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestFixture.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestFixture.cs index 89d62d61f..b8c71519e 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestFixture.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestFixture.cs @@ -14,10 +14,11 @@ public class IntegrationTestFixture : IAsyncLifetime public async Task InitializeAsync() { + var toolPath = await LambdaToolsHelper.InstallLambdaTools(); + var testAppPath = LambdaToolsHelper.GetTempTestAppDirectory( "../../../../../../..", "Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/CustomRuntimeFunctionTest"); - var toolPath = await LambdaToolsHelper.InstallLambdaTools(); _tempPaths.AddRange([testAppPath, toolPath] ); await LambdaToolsHelper.LambdaPackage(toolPath, "net8.0", testAppPath); TestAppPaths[@"CustomRuntimeFunctionTest\bin\Release\net8.0\CustomRuntimeFunctionTest.zip"] = Path.Combine(testAppPath, @"bin\Release\net8.0\CustomRuntimeFunctionTest.zip"); @@ -25,7 +26,6 @@ public async Task InitializeAsync() testAppPath = LambdaToolsHelper.GetTempTestAppDirectory( "../../../../../../..", "Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/CustomRuntimeAspNetCoreMinimalApiTest"); - toolPath = await LambdaToolsHelper.InstallLambdaTools(); _tempPaths.AddRange([testAppPath, toolPath] ); await LambdaToolsHelper.LambdaPackage(toolPath, "net8.0", testAppPath); TestAppPaths[@"CustomRuntimeAspNetCoreMinimalApiTest\bin\Release\net8.0\CustomRuntimeAspNetCoreMinimalApiTest.zip"] = Path.Combine(testAppPath, @"bin\Release\net8.0\CustomRuntimeAspNetCoreMinimalApiTest.zip"); @@ -33,19 +33,27 @@ public async Task InitializeAsync() testAppPath = LambdaToolsHelper.GetTempTestAppDirectory( "../../../../../../..", "Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/CustomRuntimeAspNetCoreMinimalApiCustomSerializerTest"); - toolPath = await LambdaToolsHelper.InstallLambdaTools(); _tempPaths.AddRange([testAppPath, toolPath] ); await LambdaToolsHelper.LambdaPackage(toolPath, "net8.0", testAppPath); TestAppPaths[@"CustomRuntimeAspNetCoreMinimalApiCustomSerializerTest\bin\Release\net8.0\CustomRuntimeAspNetCoreMinimalApiCustomSerializerTest.zip"] = Path.Combine(testAppPath, @"bin\Release\net8.0\CustomRuntimeAspNetCoreMinimalApiCustomSerializerTest.zip"); + + testAppPath = LambdaToolsHelper.GetTempTestAppDirectory( + "../../../../../../..", + "Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers"); + _tempPaths.AddRange([testAppPath, toolPath]); + await LambdaToolsHelper.LambdaPackage(toolPath, "net10.0", testAppPath); + TestAppPaths[@"ResponseStreamingFunctionHandlers\bin\Release\net10.0\ResponseStreamingFunctionHandlers.zip"] = Path.Combine(testAppPath, "bin", "Release", "net10.0", "ResponseStreamingFunctionHandlers.zip"); } public Task DisposeAsync() { +#if !DEBUG foreach (var tempPath in _tempPaths) { LambdaToolsHelper.CleanUp(tempPath); } +#endif return Task.CompletedTask; } diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/ResponseStreamingTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/ResponseStreamingTests.cs new file mode 100644 index 000000000..006df6d15 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/ResponseStreamingTests.cs @@ -0,0 +1,133 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Amazon.IdentityManagement; +using Amazon.Lambda.Model; +using Amazon.Runtime.EventStreams; +using Amazon.S3; +using Xunit; + +namespace Amazon.Lambda.RuntimeSupport.IntegrationTests +{ + [Collection("Integration Tests")] + public class ResponseStreamingTests : BaseCustomRuntimeTest + { + private readonly static string s_functionName = "IntegTestResponseStreamingFunctionHandlers" + DateTime.Now.Ticks; + + private readonly ResponseStreamingTestsFixture _streamFixture; + + public ResponseStreamingTests(IntegrationTestFixture fixture, ResponseStreamingTestsFixture streamFixture) + : base(fixture, s_functionName, "ResponseStreamingFunctionHandlers.zip", @"ResponseStreamingFunctionHandlers\bin\Release\net10.0\ResponseStreamingFunctionHandlers.zip", "ResponseStreamingFunctionHandlers") + { + _streamFixture = streamFixture; + } + + [Fact] + public async Task SimpleFunctionHandler() + { + await _streamFixture.EnsureResourcesDeployedAsync(this); + + var evnts = await InvokeFunctionAsync(nameof(SimpleFunctionHandler)); + Assert.True(evnts.Any()); + + var content = GetCombinedStreamContent(evnts); + Assert.Equal("Hello, World!", content); + } + + [Fact] + public async Task StreamContentHandler() + { + await _streamFixture.EnsureResourcesDeployedAsync(this); + + var evnts = await InvokeFunctionAsync(nameof(StreamContentHandler)); + Assert.True(evnts.Length > 5); + + var content = GetCombinedStreamContent(evnts); + Assert.Contains("Line 9999", content); + Assert.EndsWith("Finish stream content\n", content); + } + + [Fact] + public async Task UnhandledExceptionHandler() + { + await _streamFixture.EnsureResourcesDeployedAsync(this); + + var evnts = await InvokeFunctionAsync(nameof(UnhandledExceptionHandler)); + Assert.True(evnts.Any()); + + var completeEvent = evnts.Last() as InvokeWithResponseStreamCompleteEvent; + Assert.Equal("InvalidOperationException", completeEvent.ErrorCode); + Assert.Contains("This is an unhandled exception", completeEvent.ErrorDetails); + Assert.Contains("stackTrace", completeEvent.ErrorDetails); + } + + private async Task InvokeFunctionAsync(string handlerScenario) + { + using var client = new AmazonLambdaClient(TestRegion); + + var request = new InvokeWithResponseStreamRequest + { + FunctionName = base.FunctionName, + Payload = new MemoryStream(System.Text.Encoding.UTF8.GetBytes($"\"{handlerScenario}\"")), + InvocationType = ResponseStreamingInvocationType.RequestResponse + }; + + var response = await client.InvokeWithResponseStreamAsync(request); + var evnts = response.EventStream.AsEnumerable().ToArray(); + return evnts; + } + + private string GetCombinedStreamContent(IEventStreamEvent[] events) + { + var sb = new StringBuilder(); + foreach (var evnt in events) + { + if (evnt is InvokeResponseStreamUpdate chunk) + { + var text = System.Text.Encoding.UTF8.GetString(chunk.Payload.ToArray()); + sb.Append(text); + } + } + return sb.ToString(); + } + } + + public class ResponseStreamingTestsFixture : IAsyncLifetime + { + private readonly AmazonLambdaClient _lambdaClient = new AmazonLambdaClient(BaseCustomRuntimeTest.TestRegion); + private readonly AmazonS3Client _s3Client = new AmazonS3Client(BaseCustomRuntimeTest.TestRegion); + private readonly AmazonIdentityManagementServiceClient _iamClient = new AmazonIdentityManagementServiceClient(BaseCustomRuntimeTest.TestRegion); + bool _resourcesCreated; + bool _roleAlreadyExisted; + + ResponseStreamingTests _tests; + + public async Task EnsureResourcesDeployedAsync(ResponseStreamingTests tests) + { + if (_resourcesCreated) + return; + + _tests = tests; + _roleAlreadyExisted = await _tests.PrepareTestResources(_s3Client, _lambdaClient, _iamClient); + + _resourcesCreated = true; + } + + public async Task DisposeAsync() + { + await _tests.CleanUpTestResources(_s3Client, _lambdaClient, _iamClient, _roleAlreadyExisted); + + _lambdaClient.Dispose(); + _s3Client.Dispose(); + _iamClient.Dispose(); + } + + public Task InitializeAsync() => Task.CompletedTask; + } +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/HandlerTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/HandlerTests.cs index 80f9d13d0..e71acddcd 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/HandlerTests.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/HandlerTests.cs @@ -31,7 +31,7 @@ namespace Amazon.Lambda.RuntimeSupport.UnitTests { - [Collection("Bootstrap")] + [Collection("ResponseStreamFactory")] public class HandlerTests { private const string AggregateExceptionTestMarker = "AggregateExceptionTesting"; @@ -250,7 +250,7 @@ private async Task TestHandlerFailAsync(string handler, string expect var userCodeLoader = new UserCodeLoader(new SystemEnvironmentVariables(), handler, _internalLogger); var initializer = new UserCodeInitializer(userCodeLoader, _internalLogger); var handlerWrapper = HandlerWrapper.GetHandlerWrapper(userCodeLoader.Invoke); - var bootstrap = new LambdaBootstrap(handlerWrapper, initializer.InitializeAsync) + var bootstrap = new LambdaBootstrap(handlerWrapper.Handler, initializer.InitializeAsync, null, _environmentVariables) { Client = testRuntimeApiClient }; @@ -388,7 +388,9 @@ private async Task ExecHandlerAsync(string handler, string dataIn var userCodeLoader = new UserCodeLoader(new SystemEnvironmentVariables(), handler, _internalLogger); var handlerWrapper = HandlerWrapper.GetHandlerWrapper(userCodeLoader.Invoke); var initializer = new UserCodeInitializer(userCodeLoader, _internalLogger); - var bootstrap = new LambdaBootstrap(handlerWrapper, initializer.InitializeAsync) + // Pass null initializer to bootstrap so RunAsync won't re-invoke Init(), + // which would re-register AssemblyLoad event handlers and re-construct the invoke delegate. + var bootstrap = new LambdaBootstrap(handlerWrapper.Handler, null, null, _environmentVariables) { Client = testRuntimeApiClient }; @@ -403,7 +405,13 @@ private async Task ExecHandlerAsync(string handler, string dataIn Assert.DoesNotContain($"^^[{assertLoggedByInitialize}]^^", actionWriter.ToString()); } - await bootstrap.InitializeAsync(); + await initializer.InitializeAsync(); + + // Re-set logging actions after initialization in case Init's AssemblyLoad event + // handler overwrote them when loading Amazon.Lambda.Core as a handler dependency. + UserCodeLoader.SetCustomerLoggerLogAction(assembly, actionWriter.ToLoggingAction(), _internalLogger); + UserCodeLoader.SetCustomerLoggerLogAction(assembly, actionWriter.ToLoggingWithLevelAction(), _internalLogger); + UserCodeLoader.SetCustomerLoggerLogAction(assembly, actionWriter.ToLoggingWithLevelAndExceptionAction(), _internalLogger); if (assertLoggedByInitialize != null) { diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaBootstrapTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaBootstrapTests.cs index e1636ff16..76e924ac0 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaBootstrapTests.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaBootstrapTests.cs @@ -14,12 +14,14 @@ */ using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Net.Http; using System.Text; using System.Threading.Tasks; using Xunit; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; using Amazon.Lambda.RuntimeSupport.Bootstrap; using static Amazon.Lambda.RuntimeSupport.Bootstrap.Constants; @@ -29,6 +31,7 @@ namespace Amazon.Lambda.RuntimeSupport.UnitTests /// Tests to test LambdaBootstrap when it's constructed using its actual constructor. /// Tests of the static GetLambdaBootstrap methods can be found in LambdaBootstrapWrapperTests. /// + [Collection("ResponseStreamFactory")] public class LambdaBootstrapTests { readonly TestHandler _testFunction; @@ -165,7 +168,7 @@ public async Task TraceIdEnvironmentVariableIsSet() [Fact] public async Task HandlerThrowsException() { - using (var bootstrap = new LambdaBootstrap(_testFunction.BaseHandlerThrowsAsync, null)) + using (var bootstrap = new LambdaBootstrap(_testFunction.BaseHandlerThrowsAsync, null, null, _environmentVariables)) { bootstrap.Client = _testRuntimeApiClient; Assert.Null(_environmentVariables.GetEnvironmentVariable(LambdaEnvironment.EnvVarTraceId)); @@ -183,7 +186,7 @@ public async Task HandlerInputAndOutputWork() { const string testInput = "a MiXeD cAsE sTrInG"; - using (var bootstrap = new LambdaBootstrap(_testFunction.BaseHandlerToUpperAsync, null)) + using (var bootstrap = new LambdaBootstrap(_testFunction.BaseHandlerToUpperAsync, null, null, _environmentVariables)) { _testRuntimeApiClient.FunctionInput = Encoding.UTF8.GetBytes(testInput); bootstrap.Client = _testRuntimeApiClient; @@ -201,7 +204,7 @@ public async Task HandlerInputAndOutputWork() [Fact] public async Task HandlerReturnsNull() { - using (var bootstrap = new LambdaBootstrap(_testFunction.BaseHandlerReturnsNullAsync, null)) + using (var bootstrap = new LambdaBootstrap(_testFunction.BaseHandlerReturnsNullAsync, null, null, _environmentVariables)) { _testRuntimeApiClient.FunctionInput = new byte[0]; bootstrap.Client = _testRuntimeApiClient; @@ -283,5 +286,159 @@ public void IsCallPreJitTest() environmentVariables.SetEnvironmentVariable(ENVIRONMENT_VARIABLE_AWS_LAMBDA_INITIALIZATION_TYPE, AWS_LAMBDA_INITIALIZATION_TYPE_PC); Assert.True(UserCodeInit.IsCallPreJit(environmentVariables)); } + + // --- Streaming Integration Tests --- + + private TestStreamingRuntimeApiClient CreateStreamingClient() + { + var envVars = new TestEnvironmentVariables(); + var headers = new Dictionary> + { + { RuntimeApiHeaders.HeaderAwsRequestId, new List { "streaming-request-id" } }, + { RuntimeApiHeaders.HeaderInvokedFunctionArn, new List { "invoked_function_arn" } }, + { RuntimeApiHeaders.HeaderAwsTenantId, new List { "tenant_id" } } + }; + return new TestStreamingRuntimeApiClient(envVars, headers); + } + + /// + /// Property 2: CreateStream Enables Streaming Mode + /// When a handler calls ResponseStreamFactory.CreateStream(), the response is transmitted + /// using streaming mode. LambdaBootstrap awaits the send task. + /// **Validates: Requirements 1.4, 6.1, 6.2, 6.3, 6.4** + /// + [Fact] + public async Task StreamingMode_HandlerCallsCreateStream_SendTaskAwaited() + { + var streamingClient = CreateStreamingClient(); + + LambdaBootstrapHandler handler = async (invocation) => + { + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + await stream.WriteAsync(Encoding.UTF8.GetBytes("hello")); + return new InvocationResponse(Stream.Null, false); + }; + + using (var bootstrap = new LambdaBootstrap(handler, null)) + { + bootstrap.Client = streamingClient; + await bootstrap.InvokeOnceAsync(); + } + + Assert.True(streamingClient.StartStreamingResponseAsyncCalled); + Assert.False(streamingClient.SendResponseAsyncCalled); + } + + /// + /// Property 3: Default Mode Is Buffered + /// When a handler does not call ResponseStreamFactory.CreateStream(), the response + /// is transmitted using buffered mode via SendResponseAsync. + /// **Validates: Requirements 1.5, 7.2** + /// + [Fact] + public async Task BufferedMode_HandlerDoesNotCallCreateStream_UsesSendResponse() + { + var streamingClient = CreateStreamingClient(); + + LambdaBootstrapHandler handler = async (invocation) => + { + var outputStream = new MemoryStream(Encoding.UTF8.GetBytes("buffered response")); + return new InvocationResponse(outputStream); + }; + + using (var bootstrap = new LambdaBootstrap(handler, null)) + { + bootstrap.Client = streamingClient; + await bootstrap.InvokeOnceAsync(); + } + + Assert.False(streamingClient.StartStreamingResponseAsyncCalled); + Assert.True(streamingClient.SendResponseAsyncCalled); + } + + /// + /// Property 14: Exception After Writes Uses Trailers + /// When a handler throws an exception after writing data to an IResponseStream, + /// the error is reported via trailers (ReportErrorAsync) rather than standard error reporting. + /// **Validates: Requirements 5.6, 5.7** + /// + [Fact] + public async Task MidstreamError_ExceptionAfterWrites_ReportsViaTrailers() + { + var streamingClient = CreateStreamingClient(); + + LambdaBootstrapHandler handler = async (invocation) => + { + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + await stream.WriteAsync(Encoding.UTF8.GetBytes("partial data")); + throw new InvalidOperationException("midstream failure"); + }; + + using (var bootstrap = new LambdaBootstrap(handler, null)) + { + bootstrap.Client = streamingClient; + await bootstrap.InvokeOnceAsync(); + } + + // Error should be reported via trailers on the stream, not via standard error reporting + Assert.True(streamingClient.StartStreamingResponseAsyncCalled); + Assert.NotNull(streamingClient.LastStreamingResponseStream); + Assert.True(streamingClient.LastStreamingResponseStream.HasError); + Assert.False(streamingClient.ReportInvocationErrorAsyncExceptionCalled); + } + + /// + /// Property 15: Exception Before CreateStream Uses Standard Error + /// When a handler throws an exception before calling ResponseStreamFactory.CreateStream(), + /// the error is reported using the standard Lambda error reporting mechanism. + /// **Validates: Requirements 5.7, 7.1** + /// + [Fact] + public async Task PreStreamError_ExceptionBeforeCreateStream_UsesStandardErrorReporting() + { + var streamingClient = CreateStreamingClient(); + + LambdaBootstrapHandler handler = async (invocation) => + { + await Task.Yield(); + throw new InvalidOperationException("pre-stream failure"); + }; + + using (var bootstrap = new LambdaBootstrap(handler, null)) + { + bootstrap.Client = streamingClient; + await bootstrap.InvokeOnceAsync(); + } + + Assert.False(streamingClient.StartStreamingResponseAsyncCalled); + Assert.True(streamingClient.ReportInvocationErrorAsyncExceptionCalled); + } + + /// + /// State Isolation: ResponseStreamFactory state is cleared after each invocation. + /// **Validates: Requirements 6.5, 8.9** + /// + [Fact] + public async Task Cleanup_ResponseStreamFactoryStateCleared_AfterInvocation() + { + var streamingClient = CreateStreamingClient(); + + LambdaBootstrapHandler handler = async (invocation) => + { + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + await stream.WriteAsync(Encoding.UTF8.GetBytes("data")); + return new InvocationResponse(Stream.Null, false); + }; + + using (var bootstrap = new LambdaBootstrap(handler, null)) + { + bootstrap.Client = streamingClient; + await bootstrap.InvokeOnceAsync(); + } + + // After invocation, factory state should be cleaned up + Assert.Null(ResponseStreamFactory.GetStreamIfCreated(false)); + Assert.Null(ResponseStreamFactory.GetSendTask(false)); + } } } diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaResponseStreamingCoreTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaResponseStreamingCoreTests.cs new file mode 100644 index 000000000..0d5c20c86 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaResponseStreamingCoreTests.cs @@ -0,0 +1,558 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +#if NET8_0_OR_GREATER +#pragma warning disable CA2252 + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.Core.ResponseStreaming; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +using Xunit; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests +{ + // ───────────────────────────────────────────────────────────────────────────── + // HttpResponseStreamPrelude.ToByteArray() tests + // ───────────────────────────────────────────────────────────────────────────── + + public class HttpResponseStreamPreludeTests + { + private static JsonDocument ParsePrelude(HttpResponseStreamPrelude prelude) + => JsonDocument.Parse(prelude.ToByteArray()); + + [Fact] + public void ToByteArray_EmptyPrelude_ProducesEmptyJsonObject() + { + var prelude = new HttpResponseStreamPrelude(); + var doc = ParsePrelude(prelude); + + Assert.Equal(JsonValueKind.Object, doc.RootElement.ValueKind); + // No properties should be present + Assert.False(doc.RootElement.TryGetProperty("statusCode", out _)); + Assert.False(doc.RootElement.TryGetProperty("headers", out _)); + Assert.False(doc.RootElement.TryGetProperty("multiValueHeaders", out _)); + Assert.False(doc.RootElement.TryGetProperty("cookies", out _)); + } + + [Fact] + public void ToByteArray_WithStatusCode_IncludesStatusCode() + { + var prelude = new HttpResponseStreamPrelude { StatusCode = HttpStatusCode.OK }; + var doc = ParsePrelude(prelude); + + Assert.True(doc.RootElement.TryGetProperty("statusCode", out var sc)); + Assert.Equal(200, sc.GetInt32()); + } + + [Fact] + public void ToByteArray_WithHeaders_IncludesHeaders() + { + var prelude = new HttpResponseStreamPrelude + { + Headers = new Dictionary + { + ["Content-Type"] = "application/json", + ["X-Custom"] = "value" + } + }; + var doc = ParsePrelude(prelude); + + Assert.True(doc.RootElement.TryGetProperty("headers", out var headers)); + Assert.Equal("application/json", headers.GetProperty("Content-Type").GetString()); + Assert.Equal("value", headers.GetProperty("X-Custom").GetString()); + } + + [Fact] + public void ToByteArray_WithMultiValueHeaders_IncludesMultiValueHeaders() + { + var prelude = new HttpResponseStreamPrelude + { + MultiValueHeaders = new Dictionary> + { + ["Set-Cookie"] = new List { "a=1", "b=2" } + } + }; + var doc = ParsePrelude(prelude); + + Assert.True(doc.RootElement.TryGetProperty("multiValueHeaders", out var mvh)); + var cookies = mvh.GetProperty("Set-Cookie"); + Assert.Equal(JsonValueKind.Array, cookies.ValueKind); + Assert.Equal(2, cookies.GetArrayLength()); + } + + [Fact] + public void ToByteArray_WithCookies_IncludesCookies() + { + var prelude = new HttpResponseStreamPrelude + { + Cookies = new List { "session=abc", "pref=dark" } + }; + var doc = ParsePrelude(prelude); + + Assert.True(doc.RootElement.TryGetProperty("cookies", out var cookies)); + Assert.Equal(JsonValueKind.Array, cookies.ValueKind); + Assert.Equal(2, cookies.GetArrayLength()); + Assert.Equal("session=abc", cookies[0].GetString()); + } + + [Fact] + public void ToByteArray_AllFieldsPopulated_ProducesCorrectJson() + { + var prelude = new HttpResponseStreamPrelude + { + StatusCode = HttpStatusCode.Created, + Headers = new Dictionary { ["X-Req"] = "1" }, + MultiValueHeaders = new Dictionary> { ["X-Multi"] = new List { "a", "b" } }, + Cookies = new List { "c=1" } + }; + var doc = ParsePrelude(prelude); + + Assert.Equal(201, doc.RootElement.GetProperty("statusCode").GetInt32()); + Assert.Equal("1", doc.RootElement.GetProperty("headers").GetProperty("X-Req").GetString()); + Assert.Equal(2, doc.RootElement.GetProperty("multiValueHeaders").GetProperty("X-Multi").GetArrayLength()); + Assert.Equal("c=1", doc.RootElement.GetProperty("cookies")[0].GetString()); + } + + [Fact] + public void ToByteArray_EmptyCollections_OmitsThoseFields() + { + var prelude = new HttpResponseStreamPrelude + { + StatusCode = HttpStatusCode.OK, + Headers = new Dictionary(), // empty — should be omitted + MultiValueHeaders = new Dictionary>(), // empty + Cookies = new List() // empty + }; + var doc = ParsePrelude(prelude); + + Assert.True(doc.RootElement.TryGetProperty("statusCode", out _)); + Assert.False(doc.RootElement.TryGetProperty("headers", out _)); + Assert.False(doc.RootElement.TryGetProperty("multiValueHeaders", out _)); + Assert.False(doc.RootElement.TryGetProperty("cookies", out _)); + } + + [Fact] + public void ToByteArray_ProducesValidUtf8() + { + var prelude = new HttpResponseStreamPrelude + { + StatusCode = HttpStatusCode.OK, + Headers = new Dictionary { ["Content-Type"] = "text/plain; charset=utf-8" } + }; + var bytes = prelude.ToByteArray(); + + // Should not throw + var text = Encoding.UTF8.GetString(bytes); + Assert.NotEmpty(text); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // LambdaResponseStream (Stream subclass) tests + // ───────────────────────────────────────────────────────────────────────────── + + public class LambdaResponseStreamTests + { + /// + /// Creates a LambdaResponseStream backed by a real ResponseStream wired to a MemoryStream. + /// + private static async Task<(LambdaResponseStream lambdaStream, MemoryStream httpOutput)> CreateWiredLambdaStream() + { + var inner = new ResponseStream(Array.Empty()); + var output = new MemoryStream(); + await inner.SetHttpOutputStreamAsync(output); + + var implStream = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var lambdaStream = new LambdaResponseStream(implStream); + return (lambdaStream, output); + } + + [Fact] + public void LambdaResponseStream_IsStreamSubclass() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.IsAssignableFrom(stream); + } + + [Fact] + public void CanWrite_IsTrue() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.True(stream.CanWrite); + } + + [Fact] + public void CanRead_IsFalse() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.False(stream.CanRead); + } + + [Fact] + public void CanSeek_IsFalse() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.False(stream.CanSeek); + } + + [Fact] + public void Read_ThrowsNotImplementedException() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.Throws(() => stream.Read(new byte[1], 0, 1)); + } + + [Fact] + public void ReadAsync_ThrowsNotImplementedException() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + // ReadAsync throws synchronously (not async) — capture the thrown task + var ex = Assert.Throws( + () => { var _ = stream.ReadAsync(new byte[1], 0, 1, CancellationToken.None); }); + Assert.NotNull(ex); + } + + [Fact] + public void Seek_ThrowsNotImplementedException() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.Throws(() => stream.Seek(0, SeekOrigin.Begin)); + } + + [Fact] + public void Position_Get_ThrowsNotSupportedException() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.Throws(() => _ = stream.Position); + } + + [Fact] + public void Position_Set_ThrowsNotSupportedException() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.Throws(() => stream.Position = 0); + } + + [Fact] + public void SetLength_ThrowsNotSupportedException() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.Throws(() => stream.SetLength(100)); + } + + [Fact] + public async Task WriteAsync_WritesRawBytesToHttpStream() + { + var (stream, output) = await CreateWiredLambdaStream(); + var data = Encoding.UTF8.GetBytes("hello streaming"); + + await stream.WriteAsync(data, 0, data.Length); + + Assert.Equal(data, output.ToArray()); + } + + [Fact] + public async Task Write_SyncOverload_WritesRawBytes() + { + var (stream, output) = await CreateWiredLambdaStream(); + var data = new byte[] { 1, 2, 3 }; + + stream.Write(data, 0, data.Length); + + Assert.Equal(data, output.ToArray()); + } + + [Fact] + public async Task Length_ReflectsBytesWritten() + { + var (stream, _) = await CreateWiredLambdaStream(); + var data = new byte[42]; + + await stream.WriteAsync(data, 0, data.Length); + + Assert.Equal(42, stream.Length); + Assert.Equal(42, stream.BytesWritten); + } + + [Fact] + public async Task Flush_IsNoOp() + { + var (stream, _) = await CreateWiredLambdaStream(); + // Should not throw + stream.Flush(); + } + + [Fact] + public async Task WriteAsync_ByteArrayOverload_WritesFullArray() + { + var (stream, output) = await CreateWiredLambdaStream(); + var data = new byte[] { 0xDE, 0xAD, 0xBE, 0xEF }; + + await stream.WriteAsync(data); + + Assert.Equal(data, output.ToArray()); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // ImplLambdaResponseStream (bridge class) tests + // ───────────────────────────────────────────────────────────────────────────── + + public class ImplLambdaResponseStreamTests + { + [Fact] + public async Task WriteAsync_DelegatesToInnerResponseStream() + { + var inner = new ResponseStream(Array.Empty()); + var output = new MemoryStream(); + await inner.SetHttpOutputStreamAsync(output); + + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var data = new byte[] { 1, 2, 3 }; + + await impl.WriteAsync(data, 0, data.Length); + + Assert.Equal(data, output.ToArray()); + } + + [Fact] + public async Task BytesWritten_ReflectsInnerStreamBytesWritten() + { + var inner = new ResponseStream(Array.Empty()); + var output = new MemoryStream(); + await inner.SetHttpOutputStreamAsync(output); + + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + await impl.WriteAsync(new byte[7], 0, 7); + + Assert.Equal(7, impl.BytesWritten); + } + + [Fact] + public void HasError_InitiallyFalse() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + + Assert.False(impl.HasError); + } + + [Fact] + public void HasError_TrueAfterReportError() + { + var inner = new ResponseStream(Array.Empty()); + inner.ReportError(new Exception("test")); + + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + + Assert.True(impl.HasError); + } + + [Fact] + public void Dispose_DisposesInnerStream() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + + // Should not throw + impl.Dispose(); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // LambdaResponseStreamFactory tests + // ───────────────────────────────────────────────────────────────────────────── + + [Collection("ResponseStreamFactory")] + public class LambdaResponseStreamFactoryTests : IDisposable + { + + public LambdaResponseStreamFactoryTests() + { + // Wire up the factory via the initializer (same as production bootstrap does) + ResponseStreamLambdaCoreInitializerIsolated.InitializeCore(); + } + + public void Dispose() + { + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: false); + } + + private void InitializeInvocation(string requestId = "test-req") + { + var envVars = new TestEnvironmentVariables(); + var client = new NoOpStreamingRuntimeApiClient(envVars); + ResponseStreamFactory.InitializeInvocation(requestId, false, client, CancellationToken.None); + } + + /// + /// Minimal RuntimeApiClient that accepts StartStreamingResponseAsync without real HTTP. + /// + private class NoOpStreamingRuntimeApiClient : RuntimeApiClient + { + public NoOpStreamingRuntimeApiClient(IEnvironmentVariables envVars) + : base(envVars, new TestHelpers.NoOpInternalRuntimeApiClient()) { } + + internal override async Task StartStreamingResponseAsync( + string awsRequestId, ResponseStream responseStream, CancellationToken cancellationToken = default) + { + // Provide the HTTP output stream so writes don't block + await responseStream.SetHttpOutputStreamAsync(new MemoryStream(), cancellationToken); + await responseStream.WaitForCompletionAsync(cancellationToken); + return new NoOpDisposable(); + } + } + + [Fact] + public void CreateStream_ReturnsLambdaResponseStream() + { + InitializeInvocation(); + + var stream = LambdaResponseStreamFactory.CreateStream(); + + Assert.NotNull(stream); + Assert.IsType(stream); + } + + [Fact] + public void CreateStream_ReturnsStreamSubclass() + { + InitializeInvocation(); + + var stream = LambdaResponseStreamFactory.CreateStream(); + + Assert.IsAssignableFrom(stream); + } + + [Fact] + public void CreateStream_ReturnedStream_IsWritable() + { + InitializeInvocation(); + + var stream = LambdaResponseStreamFactory.CreateStream(); + + Assert.True(stream.CanWrite); + } + + [Fact] + public void CreateStream_ReturnedStream_IsNotSeekable() + { + InitializeInvocation(); + + var stream = LambdaResponseStreamFactory.CreateStream(); + + Assert.False(stream.CanSeek); + } + + [Fact] + public void CreateStream_ReturnedStream_IsNotReadable() + { + InitializeInvocation(); + + var stream = LambdaResponseStreamFactory.CreateStream(); + + Assert.False(stream.CanRead); + } + + [Fact] + public void CreateHttpStream_WithPrelude_ReturnsLambdaResponseStream() + { + InitializeInvocation(); + + var prelude = new HttpResponseStreamPrelude { StatusCode = HttpStatusCode.OK }; + var stream = LambdaResponseStreamFactory.CreateHttpStream(prelude); + + Assert.NotNull(stream); + Assert.IsType(stream); + } + + [Fact] + public void CreateHttpStream_PassesSerializedPreludeToFactory() + { + // Capture the prelude bytes passed to the inner factory + byte[] capturedPrelude = null; + LambdaResponseStreamFactory.SetLambdaResponseStream(prelude => + { + capturedPrelude = prelude; + // Return a minimal stub that satisfies the interface + return new StubLambdaResponseStream(); + }); + + var httpPrelude = new HttpResponseStreamPrelude + { + StatusCode = HttpStatusCode.Created, + Headers = new Dictionary { ["X-Test"] = "1" } + }; + LambdaResponseStreamFactory.CreateHttpStream(httpPrelude); + + Assert.NotNull(capturedPrelude); + Assert.True(capturedPrelude.Length > 0); + + // Verify the bytes are valid JSON containing the status code + var doc = JsonDocument.Parse(capturedPrelude); + Assert.Equal(201, doc.RootElement.GetProperty("statusCode").GetInt32()); + } + + [Fact] + public void CreateStream_PassesEmptyPreludeToFactory() + { + byte[] capturedPrelude = null; + LambdaResponseStreamFactory.SetLambdaResponseStream(prelude => + { + capturedPrelude = prelude; + return new StubLambdaResponseStream(); + }); + + LambdaResponseStreamFactory.CreateStream(); + + Assert.NotNull(capturedPrelude); + Assert.Empty(capturedPrelude); + } + + private class StubLambdaResponseStream : ILambdaResponseStream + { + public long BytesWritten => 0; + public bool HasError => false; + public void Dispose() { } + public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + => Task.CompletedTask; + } + } +} +#endif diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/RawStreamingHttpClientTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/RawStreamingHttpClientTests.cs new file mode 100644 index 000000000..e203d6968 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/RawStreamingHttpClientTests.cs @@ -0,0 +1,502 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +#if NET8_0_OR_GREATER + +using System; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +using Xunit; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests +{ + // ───────────────────────────────────────────────────────────────────────────── + // RawStreamingHttpClient tests + // ───────────────────────────────────────────────────────────────────────────── + + public class RawStreamingHttpClientTests + { + // --- Constructor / host parsing --- + + [Fact] + public void Constructor_HostAndPort_ParsedCorrectly() + { + using var client = new RawStreamingHttpClient("localhost:9001"); + // No exception means parsing succeeded. Fields are private but + // we verify indirectly via Dispose not throwing. + } + + [Fact] + public void Constructor_HostOnly_DefaultsToPort80() + { + using var client = new RawStreamingHttpClient("localhost"); + // Should not throw — defaults port to 80 + } + + [Fact] + public void Constructor_HighPort_ParsedCorrectly() + { + using var client = new RawStreamingHttpClient("127.0.0.1:65535"); + } + + // --- Dispose --- + + [Fact] + public void Dispose_CalledTwice_DoesNotThrow() + { + var client = new RawStreamingHttpClient("localhost:9001"); + client.Dispose(); + client.Dispose(); + } + + [Fact] + public void Dispose_WithoutConnect_DoesNotThrow() + { + var client = new RawStreamingHttpClient("localhost:9001"); + client.Dispose(); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // WriteTerminatorWithTrailersAsync tests + // ───────────────────────────────────────────────────────────────────────────── + + public class WriteTerminatorWithTrailersAsyncTests + { + private static (RawStreamingHttpClient client, MemoryStream output) CreateClientWithMemoryStream() + { + var client = new RawStreamingHttpClient("localhost:9001"); + var output = new MemoryStream(); + client._networkStream = output; + return (client, output); + } + + [Fact] + public async Task WriteTerminator_StartsWithZeroChunk() + { + var (client, output) = CreateClientWithMemoryStream(); + + await client.WriteTerminatorWithTrailersAsync( + new Exception("test"), CancellationToken.None); + + var written = Encoding.UTF8.GetString(output.ToArray()); + Assert.StartsWith("0\r\n", written); + } + + [Fact] + public async Task WriteTerminator_ContainsErrorTypeTrailer() + { + var (client, output) = CreateClientWithMemoryStream(); + + await client.WriteTerminatorWithTrailersAsync( + new InvalidOperationException("bad op"), CancellationToken.None); + + var written = Encoding.UTF8.GetString(output.ToArray()); + Assert.Contains($"{StreamingConstants.ErrorTypeTrailer}: InvalidOperationException\r\n", written); + } + + [Fact] + public async Task WriteTerminator_ContainsErrorBodyTrailerHeader() + { + var (client, output) = CreateClientWithMemoryStream(); + + await client.WriteTerminatorWithTrailersAsync( + new Exception("some error"), CancellationToken.None); + + var written = Encoding.UTF8.GetString(output.ToArray()); + Assert.Contains($"{StreamingConstants.ErrorBodyTrailer}: ", written); + } + + [Fact] + public async Task WriteTerminator_ErrorBodyIsBase64Encoded() + { + var (client, output) = CreateClientWithMemoryStream(); + const string errorMessage = "something broke"; + + await client.WriteTerminatorWithTrailersAsync( + new Exception(errorMessage), CancellationToken.None); + + var written = Encoding.UTF8.GetString(output.ToArray()); + + // Extract the Base64 value from the error body trailer + var prefix = $"{StreamingConstants.ErrorBodyTrailer}: "; + var start = written.IndexOf(prefix, StringComparison.Ordinal) + prefix.Length; + var end = written.IndexOf("\r\n", start, StringComparison.Ordinal); + var base64Value = written.Substring(start, end - start); + + // Should be valid Base64 + var decoded = Encoding.UTF8.GetString(Convert.FromBase64String(base64Value)); + Assert.Contains(errorMessage, decoded); + } + + [Fact] + public async Task WriteTerminator_ErrorBodyBase64ContainsNoNewlines() + { + var (client, output) = CreateClientWithMemoryStream(); + + // Use an exception with a stack trace that would produce multi-line JSON + Exception caughtException; + try { throw new InvalidOperationException("multi\nline\nerror"); } + catch (Exception ex) { caughtException = ex; } + + await client.WriteTerminatorWithTrailersAsync( + caughtException, CancellationToken.None); + + var written = Encoding.UTF8.GetString(output.ToArray()); + + // Extract just the error body trailer line + var prefix = $"{StreamingConstants.ErrorBodyTrailer}: "; + var start = written.IndexOf(prefix, StringComparison.Ordinal) + prefix.Length; + var end = written.IndexOf("\r\n", start, StringComparison.Ordinal); + var base64Value = written.Substring(start, end - start); + + // The Base64 value itself must not contain any newlines + Assert.DoesNotContain("\n", base64Value); + Assert.DoesNotContain("\r", base64Value); + } + + [Fact] + public async Task WriteTerminator_EndsWithEmptyLine() + { + var (client, output) = CreateClientWithMemoryStream(); + + await client.WriteTerminatorWithTrailersAsync( + new Exception("test"), CancellationToken.None); + + var written = Encoding.UTF8.GetString(output.ToArray()); + // Must end with \r\n\r\n — the last trailer line's \r\n plus the empty terminator line + Assert.EndsWith("\r\n\r\n", written); + } + + [Fact] + public async Task WriteTerminator_CorrectWireFormat() + { + var (client, output) = CreateClientWithMemoryStream(); + + await client.WriteTerminatorWithTrailersAsync( + new ArgumentException("bad arg"), CancellationToken.None); + + var written = Encoding.UTF8.GetString(output.ToArray()); + var lines = written.Split("\r\n"); + + // Line 0: "0" (zero-length chunk) + Assert.Equal("0", lines[0]); + // Line 1: error type trailer + Assert.StartsWith($"{StreamingConstants.ErrorTypeTrailer}: ", lines[1]); + // Line 2: error body trailer (Base64) + Assert.StartsWith($"{StreamingConstants.ErrorBodyTrailer}: ", lines[2]); + // Line 3: empty (end of trailers) + Assert.Equal("", lines[3]); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // ReadAndDiscardResponseAsync tests + // ───────────────────────────────────────────────────────────────────────────── + + public class ReadAndDiscardResponseAsyncTests + { + private static (RawStreamingHttpClient client, MemoryStream input) CreateClientWithResponse(string httpResponse) + { + var client = new RawStreamingHttpClient("localhost:9001"); + var input = new MemoryStream(Encoding.ASCII.GetBytes(httpResponse)); + client._networkStream = input; + return (client, input); + } + + [Fact] + public async Task ReadAndDiscard_HeadersOnly_CompletesSuccessfully() + { + var (client, _) = CreateClientWithResponse( + "HTTP/1.1 202 Accepted\r\nContent-Length: 0\r\n\r\n"); + + await client.ReadAndDiscardResponseAsync(CancellationToken.None); + // Should complete without error + } + + [Fact] + public async Task ReadAndDiscard_WithBody_ReadsFullBody() + { + var body = "OK"; + var (client, _) = CreateClientWithResponse( + $"HTTP/1.1 200 OK\r\nContent-Length: {body.Length}\r\n\r\n{body}"); + + await client.ReadAndDiscardResponseAsync(CancellationToken.None); + } + + [Fact] + public async Task ReadAndDiscard_NoContentLength_CompletesAfterHeaders() + { + var (client, _) = CreateClientWithResponse( + "HTTP/1.1 202 Accepted\r\n\r\n"); + + await client.ReadAndDiscardResponseAsync(CancellationToken.None); + } + + [Fact] + public async Task ReadAndDiscard_EmptyStream_CompletesSuccessfully() + { + var client = new RawStreamingHttpClient("localhost:9001"); + client._networkStream = new MemoryStream(Array.Empty()); + + await client.ReadAndDiscardResponseAsync(CancellationToken.None); + } + + [Fact] + public async Task ReadAndDiscard_PartialBody_WaitsForFullBody() + { + // Content-Length says 10 but we provide all 10 bytes + var body = "0123456789"; + var (client, _) = CreateClientWithResponse( + $"HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\n{body}"); + + await client.ReadAndDiscardResponseAsync(CancellationToken.None); + } + + [Fact] + public async Task ReadAndDiscard_CancellationToken_Respected() + { + // Use a stream that blocks on read to test cancellation + var cts = new CancellationTokenSource(); + cts.Cancel(); + + var client = new RawStreamingHttpClient("localhost:9001"); + client._networkStream = new MemoryStream(Encoding.ASCII.GetBytes( + "HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\n")); + + // Should not throw — ReadAndDiscardResponseAsync catches exceptions + await client.ReadAndDiscardResponseAsync(cts.Token); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // ChunkedStreamWriter tests + // ───────────────────────────────────────────────────────────────────────────── + + public class ChunkedStreamWriterTests + { + [Fact] + public void CanWrite_IsTrue() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.True(writer.CanWrite); + } + + [Fact] + public void CanRead_IsFalse() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.False(writer.CanRead); + } + + [Fact] + public void CanSeek_IsFalse() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.False(writer.CanSeek); + } + + [Fact] + public void Constructor_NullStream_ThrowsArgumentNullException() + { + Assert.Throws(() => new ChunkedStreamWriter(null)); + } + + [Fact] + public void Length_ThrowsNotSupportedException() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.Throws(() => writer.Length); + } + + [Fact] + public void Position_Get_ThrowsNotSupportedException() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.Throws(() => writer.Position); + } + + [Fact] + public void Position_Set_ThrowsNotSupportedException() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.Throws(() => writer.Position = 0); + } + + [Fact] + public void Read_ThrowsNotSupportedException() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.Throws(() => writer.Read(new byte[1], 0, 1)); + } + + [Fact] + public void Seek_ThrowsNotSupportedException() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.Throws(() => writer.Seek(0, SeekOrigin.Begin)); + } + + [Fact] + public void SetLength_ThrowsNotSupportedException() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.Throws(() => writer.SetLength(0)); + } + + [Fact] + public async Task WriteAsync_ByteArray_ProducesCorrectChunkFormat() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + var data = Encoding.UTF8.GetBytes("Hello"); + await writer.WriteAsync(data, 0, data.Length); + + var output = Encoding.ASCII.GetString(inner.ToArray()); + // "Hello" is 5 bytes = 0x5 + Assert.Equal("5\r\nHello\r\n", output); + } + + [Fact] + public async Task WriteAsync_ReadOnlyMemory_ProducesCorrectChunkFormat() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + var data = Encoding.UTF8.GetBytes("Hi"); + await writer.WriteAsync(new ReadOnlyMemory(data)); + + var output = Encoding.ASCII.GetString(inner.ToArray()); + Assert.Equal("2\r\nHi\r\n", output); + } + + [Fact] + public async Task WriteAsync_ZeroBytes_WritesNothing() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + await writer.WriteAsync(Array.Empty(), 0, 0); + + Assert.Equal(0, inner.Length); + } + + [Fact] + public async Task WriteAsync_ReadOnlyMemory_ZeroBytes_WritesNothing() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + await writer.WriteAsync(ReadOnlyMemory.Empty); + + Assert.Equal(0, inner.Length); + } + + [Fact] + public async Task WriteAsync_MultipleChunks_EachCorrectlyFormatted() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + await writer.WriteAsync(Encoding.UTF8.GetBytes("AB"), 0, 2); + await writer.WriteAsync(Encoding.UTF8.GetBytes("CDE"), 0, 3); + + var output = Encoding.ASCII.GetString(inner.ToArray()); + Assert.Equal("2\r\nAB\r\n3\r\nCDE\r\n", output); + } + + [Fact] + public async Task WriteAsync_LargeChunk_HexSizeCorrect() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + var data = new byte[256]; + Array.Fill(data, (byte)'X'); + await writer.WriteAsync(data, 0, data.Length); + + var output = Encoding.ASCII.GetString(inner.ToArray()); + // 256 = 0x100 + Assert.StartsWith("100\r\n", output); + Assert.EndsWith("\r\n", output); + } + + [Fact] + public async Task WriteAsync_WithOffset_WritesCorrectSlice() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + var data = Encoding.UTF8.GetBytes("ABCDE"); + await writer.WriteAsync(data, 1, 3); // "BCD" + + var output = Encoding.ASCII.GetString(inner.ToArray()); + Assert.Equal("3\r\nBCD\r\n", output); + } + + [Fact] + public void Write_Sync_ProducesCorrectChunkFormat() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + var data = Encoding.UTF8.GetBytes("OK"); + writer.Write(data, 0, data.Length); + + var output = Encoding.ASCII.GetString(inner.ToArray()); + Assert.Equal("2\r\nOK\r\n", output); + } + + [Fact] + public async Task FlushAsync_DelegatesToInnerStream() + { + var flushCalled = false; + var inner = new FlushTrackingStream(() => flushCalled = true); + using var writer = new ChunkedStreamWriter(inner); + + await writer.FlushAsync(CancellationToken.None); + + Assert.True(flushCalled); + } + + [Fact] + public void Flush_DelegatesToInnerStream() + { + var flushCalled = false; + var inner = new FlushTrackingStream(() => flushCalled = true); + using var writer = new ChunkedStreamWriter(inner); + + writer.Flush(); + + Assert.True(flushCalled); + } + + /// + /// A minimal writable stream that tracks Flush calls. + /// + private class FlushTrackingStream : MemoryStream + { + private readonly Action _onFlush; + public FlushTrackingStream(Action onFlush) => _onFlush = onFlush; + public override void Flush() { _onFlush(); base.Flush(); } + public override Task FlushAsync(CancellationToken cancellationToken) + { + _onFlush(); + return base.FlushAsync(cancellationToken); + } + } + } +} +#endif diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/ResponseStreamFactoryTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/ResponseStreamFactoryTests.cs new file mode 100644 index 000000000..cc9a19af2 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/ResponseStreamFactoryTests.cs @@ -0,0 +1,284 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +using Xunit; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests +{ + [Collection("ResponseStreamFactory")] + public class ResponseStreamFactoryTests : IDisposable + { + private const long MaxResponseSize = 20 * 1024 * 1024; + + public void Dispose() + { + // Clean up both modes to avoid test pollution + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: false); + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: true); + } + + /// + /// A minimal RuntimeApiClient subclass for testing that overrides StartStreamingResponseAsync + /// to avoid real HTTP calls while tracking invocations. + /// + private class MockStreamingRuntimeApiClient : RuntimeApiClient + { + public bool StartStreamingCalled { get; private set; } + public string LastAwsRequestId { get; private set; } + public ResponseStream LastResponseStream { get; private set; } + public TaskCompletionSource SendTaskCompletion { get; } = new TaskCompletionSource(); + + public MockStreamingRuntimeApiClient() + : base(new TestEnvironmentVariables(), new TestHelpers.NoOpInternalRuntimeApiClient()) + { + } + + internal override async Task StartStreamingResponseAsync( + string awsRequestId, ResponseStream responseStream, CancellationToken cancellationToken = default) + { + StartStreamingCalled = true; + LastAwsRequestId = awsRequestId; + LastResponseStream = responseStream; + await SendTaskCompletion.Task; + return new NoOpDisposable(); + } + } + + private void InitializeWithMock(string requestId, bool isMultiConcurrency, MockStreamingRuntimeApiClient mockClient) + { + ResponseStreamFactory.InitializeInvocation( + requestId, isMultiConcurrency, + mockClient, CancellationToken.None); + } + + // --- Property 1: CreateStream Returns Valid Stream --- + + /// + /// Property 1: CreateStream Returns Valid Stream - on-demand mode. + /// Validates: Requirements 1.3, 2.2, 2.3 + /// + [Fact] + public void CreateStream_OnDemandMode_ReturnsValidStream() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-1", isMultiConcurrency: false, mock); + + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + + Assert.NotNull(stream); + Assert.IsAssignableFrom(stream); + } + + /// + /// Property 1: CreateStream Returns Valid Stream - multi-concurrency mode. + /// Validates: Requirements 1.3, 2.2, 2.3 + /// + [Fact] + public void CreateStream_MultiConcurrencyMode_ReturnsValidStream() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-2", isMultiConcurrency: true, mock); + + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + + Assert.NotNull(stream); + Assert.IsAssignableFrom(stream); + } + + // --- Property 4: Single Stream Per Invocation --- + + /// + /// Property 4: Single Stream Per Invocation - calling CreateStream twice throws. + /// Validates: Requirements 2.5, 2.6 + /// + [Fact] + public void CreateStream_CalledTwice_ThrowsInvalidOperationException() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-3", isMultiConcurrency: false, mock); + ResponseStreamFactory.CreateStream(Array.Empty()); + + Assert.Throws(() => ResponseStreamFactory.CreateStream(Array.Empty())); + } + + [Fact] + public void CreateStream_OutsideInvocationContext_ThrowsInvalidOperationException() + { + // No InitializeInvocation called + Assert.Throws(() => ResponseStreamFactory.CreateStream(Array.Empty())); + } + + // --- CreateStream starts HTTP POST --- + + /// + /// Validates that CreateStream calls StartStreamingResponseAsync on the RuntimeApiClient. + /// Validates: Requirements 1.3, 1.4, 2.2, 2.3, 2.4 + /// + [Fact] + public void CreateStream_CallsStartStreamingResponseAsync() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-start", isMultiConcurrency: false, mock); + + ResponseStreamFactory.CreateStream(Array.Empty()); + + Assert.True(mock.StartStreamingCalled); + Assert.Equal("req-start", mock.LastAwsRequestId); + Assert.NotNull(mock.LastResponseStream); + } + + // --- GetSendTask --- + + /// + /// Validates that GetSendTask returns the task from the HTTP POST. + /// Validates: Requirements 5.1, 7.3 + /// + [Fact] + public void GetSendTask_AfterCreateStream_ReturnsNonNullTask() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-send", isMultiConcurrency: false, mock); + + ResponseStreamFactory.CreateStream(Array.Empty()); + + var sendTask = ResponseStreamFactory.GetSendTask(isMultiConcurrency: false); + Assert.NotNull(sendTask); + } + + [Fact] + public void GetSendTask_BeforeCreateStream_ReturnsNull() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-nosend", isMultiConcurrency: false, mock); + + var sendTask = ResponseStreamFactory.GetSendTask(isMultiConcurrency: false); + Assert.Null(sendTask); + } + + [Fact] + public void GetSendTask_NoContext_ReturnsNull() + { + Assert.Null(ResponseStreamFactory.GetSendTask(isMultiConcurrency: false)); + } + + // --- Internal methods --- + + [Fact] + public void InitializeInvocation_OnDemand_SetsUpContext() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-4", isMultiConcurrency: false, mock); + + Assert.Null(ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: false)); + + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + Assert.NotNull(stream); + } + + [Fact] + public void InitializeInvocation_MultiConcurrency_SetsUpContext() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-5", isMultiConcurrency: true, mock); + + Assert.Null(ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: true)); + + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + Assert.NotNull(stream); + } + + [Fact] + public void GetStreamIfCreated_AfterCreateStream_ReturnsStream() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-6", isMultiConcurrency: false, mock); + ResponseStreamFactory.CreateStream(Array.Empty()); + + var retrieved = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: false); + Assert.NotNull(retrieved); + } + + [Fact] + public void GetStreamIfCreated_NoContext_ReturnsNull() + { + Assert.Null(ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: false)); + } + + [Fact] + public void CleanupInvocation_ClearsState() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-7", isMultiConcurrency: false, mock); + ResponseStreamFactory.CreateStream(Array.Empty()); + + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: false); + + Assert.Null(ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: false)); + Assert.Throws(() => ResponseStreamFactory.CreateStream(Array.Empty())); + } + + // --- Property 16: State Isolation Between Invocations --- + + /// + /// Property 16: State Isolation Between Invocations - state from one invocation doesn't leak to the next. + /// Validates: Requirements 6.5, 8.9 + /// + [Fact] + public void StateIsolation_SequentialInvocations_NoLeakage() + { + var mock = new MockStreamingRuntimeApiClient(); + + // First invocation - streaming + InitializeWithMock("req-8a", isMultiConcurrency: false, mock); + var stream1 = ResponseStreamFactory.CreateStream(Array.Empty()); + Assert.NotNull(stream1); + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: false); + + // Second invocation - should start fresh + InitializeWithMock("req-8b", isMultiConcurrency: false, mock); + Assert.Null(ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: false)); + + var stream2 = ResponseStreamFactory.CreateStream(Array.Empty()); + Assert.NotNull(stream2); + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: false); + } + + /// + /// Property 16: State Isolation - multi-concurrency mode uses AsyncLocal. + /// Validates: Requirements 2.9, 2.10 + /// + [Fact] + public async Task StateIsolation_MultiConcurrency_UsesAsyncLocal() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-9", isMultiConcurrency: true, mock); + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + Assert.NotNull(stream); + + bool childSawNull = false; + await Task.Run(() => + { + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: true); + childSawNull = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: true) == null; + }); + + Assert.True(childSawNull); + } + } +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/ResponseStreamTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/ResponseStreamTests.cs new file mode 100644 index 000000000..cd2c00fd2 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/ResponseStreamTests.cs @@ -0,0 +1,447 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +using Xunit; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests +{ + public class ResponseStreamTests + { + /// + /// Helper: creates a ResponseStream and wires up a MemoryStream as the HTTP output stream. + /// Returns both so tests can inspect what was written. + /// + private static async Task<(ResponseStream stream, MemoryStream httpOutput)> CreateWiredStream() + { + var rs = new ResponseStream(Array.Empty()); + var output = new MemoryStream(); + await rs.SetHttpOutputStreamAsync(output); + return (rs, output); + } + + // ---- Basic state tests ---- + + [Fact] + public void Constructor_InitializesStateCorrectly() + { + var stream = new ResponseStream(Array.Empty()); + + Assert.Equal(0, stream.BytesWritten); + Assert.False(stream.HasError); + Assert.Null(stream.ReportedError); + } + + [Fact] + public async Task WriteAsync_WithOffset_WritesCorrectSlice() + { + var (stream, httpOutput) = await CreateWiredStream(); + var data = new byte[] { 0, 1, 2, 3, 0 }; + + await stream.WriteAsync(data, 1, 3); + + // Raw bytes {1,2,3} written directly — no chunked encoding + var expected = new byte[] { 1, 2, 3 }; + Assert.Equal(expected, httpOutput.ToArray()); + } + + [Fact] + public async Task WriteAsync_MultipleWrites_EachAppearsImmediately() + { + var (stream, httpOutput) = await CreateWiredStream(); + + var data = new byte[] { 0xAA }; + await stream.WriteAsync(data, 0, data.Length); + var afterFirst = httpOutput.ToArray().Length; + Assert.True(afterFirst > 0, "First chunk should be on the HTTP stream immediately after WriteAsync returns"); + + await stream.WriteAsync(new byte[] { 0xBB, 0xCC }, 0, 2); + var afterSecond = httpOutput.ToArray().Length; + Assert.True(afterSecond > afterFirst, "Second chunk should appear on the HTTP stream immediately"); + + Assert.Equal(3, stream.BytesWritten); + } + + [Fact] + public async Task WriteAsync_BlocksUntilSetHttpOutputStream() + { + var rs = new ResponseStream(Array.Empty()); + var httpOutput = new MemoryStream(); + var writeStarted = new ManualResetEventSlim(false); + var writeCompleted = new ManualResetEventSlim(false); + + // Start a write on a background thread — it should block + var writeTask = Task.Run(async () => + { + writeStarted.Set(); + await rs.WriteAsync(new byte[] { 1, 2, 3 }, 0, 3); + writeCompleted.Set(); + }); + + // Wait for the write to start, then verify it hasn't completed + writeStarted.Wait(TimeSpan.FromSeconds(2)); + await Task.Delay(100); // give it a moment + Assert.False(writeCompleted.IsSet, "WriteAsync should block until SetHttpOutputStream is called"); + + // Now provide the HTTP stream — the write should complete + await rs.SetHttpOutputStreamAsync(httpOutput); + await writeTask; + + Assert.True(writeCompleted.IsSet); + Assert.True(httpOutput.ToArray().Length > 0); + } + + [Fact] + public async Task MarkCompleted_ReleasesCompletionSignal() + { + var (stream, _) = await CreateWiredStream(); + + var waitTask = stream.WaitForCompletionAsync(); + Assert.False(waitTask.IsCompleted, "WaitForCompletionAsync should block before MarkCompleted"); + + stream.MarkCompleted(); + + // Should complete within a reasonable time + var completed = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(2))); + Assert.Same(waitTask, completed); + } + + [Fact] + public async Task ReportErrorAsync_ReleasesCompletionSignal() + { + var (stream, _) = await CreateWiredStream(); + + var waitTask = stream.WaitForCompletionAsync(); + Assert.False(waitTask.IsCompleted, "WaitForCompletionAsync should block before ReportErrorAsync"); + + stream.ReportError(new Exception("test error")); + + var completed = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(2))); + Assert.Same(waitTask, completed); + Assert.True(stream.HasError); + } + + [Fact] + public async Task WriteAsync_AfterMarkCompleted_StillSucceeds() + { + var (stream, output) = await CreateWiredStream(); + await stream.WriteAsync(new byte[] { 1 }, 0, 1); + stream.MarkCompleted(); + + // Writes after MarkCompleted are allowed — buffered ASP.NET Core responses + // (e.g. Results.Json) may flush pre-start buffer data after the pipeline + // completes and LambdaBootstrap calls MarkCompleted. + await stream.WriteAsync(new byte[] { 2 }, 0, 1); + + Assert.Equal(new byte[] { 1, 2 }, output.ToArray()); + } + + [Fact] + public async Task WriteAsync_AfterReportError_Throws() + { + var (stream, _) = await CreateWiredStream(); + await stream.WriteAsync(new byte[] { 1 }, 0, 1); + stream.ReportError(new Exception("test")); + + await Assert.ThrowsAsync( + () => stream.WriteAsync(new byte[] { 2 }, 0, 1)); + } + + [Fact] + public async Task ReportErrorAsync_SetsErrorState() + { + var stream = new ResponseStream(Array.Empty()); + var exception = new InvalidOperationException("something broke"); + + stream.ReportError(exception); + + Assert.True(stream.HasError); + Assert.Same(exception, stream.ReportedError); + } + + [Fact] + public async Task ReportErrorAsync_AfterCompleted_Throws() + { + var stream = new ResponseStream(Array.Empty()); + stream.MarkCompleted(); + + Assert.Throws( + () => stream.ReportError(new Exception("test"))); + } + + [Fact] + public async Task ReportErrorAsync_CalledTwice_Throws() + { + var stream = new ResponseStream(Array.Empty()); + stream.ReportError(new Exception("first")); + + Assert.Throws( + () => stream.ReportError(new Exception("second"))); + } + + [Fact] + public async Task WriteAsync_NullBuffer_ThrowsArgumentNull() + { + var (stream, _) = await CreateWiredStream(); + + await Assert.ThrowsAsync(() => stream.WriteAsync((byte[])null, 0, 0)); + } + + [Fact] + public async Task WriteAsync_NullBufferWithOffset_ThrowsArgumentNull() + { + var (stream, _) = await CreateWiredStream(); + + await Assert.ThrowsAsync(() => stream.WriteAsync(null, 0, 0)); + } + + [Fact] + public async Task ReportErrorAsync_NullException_ThrowsArgumentNull() + { + var stream = new ResponseStream(Array.Empty()); + + Assert.Throws(() => stream.ReportError(null)); + } + + [Fact] + public async Task Dispose_CalledTwice_DoesNotThrow() + { + var stream = new ResponseStream(Array.Empty()); + stream.Dispose(); + // Second dispose should be a no-op + stream.Dispose(); + } + + // ---- Prelude tests ---- + + [Fact] + public async Task SetHttpOutputStreamAsync_WithPrelude_WritesPreludeBeforeHandlerData() + { + var prelude = new byte[] { 0x01, 0x02, 0x03 }; + var rs = new ResponseStream(prelude); + var output = new MemoryStream(); + + await rs.SetHttpOutputStreamAsync(output); + + // Prelude bytes + 8-byte null delimiter should be written before any handler data + var written = output.ToArray(); + Assert.True(written.Length >= prelude.Length + 8, "Prelude + delimiter should be written"); + Assert.Equal(prelude, written[..prelude.Length]); + Assert.Equal(new byte[8], written[prelude.Length..(prelude.Length + 8)]); + } + + [Fact] + public async Task SetHttpOutputStreamAsync_WithEmptyPrelude_WritesNoPreludeBytes() + { + var rs = new ResponseStream(Array.Empty()); + var output = new MemoryStream(); + + await rs.SetHttpOutputStreamAsync(output); + + // Empty prelude — nothing written yet (handler hasn't written anything) + Assert.Empty(output.ToArray()); + } + + [Fact] + public async Task SetHttpOutputStreamAsync_WithPrelude_HandlerDataAppendsAfterDelimiter() + { + var prelude = new byte[] { 0xAA, 0xBB }; + var rs = new ResponseStream(prelude); + var output = new MemoryStream(); + + await rs.SetHttpOutputStreamAsync(output); + await rs.WriteAsync(new byte[] { 0xFF }, 0, 1); + + var written = output.ToArray(); + // Layout: [prelude][8 null bytes][handler data] + int expectedMinLength = prelude.Length + 8 + 1; + Assert.Equal(expectedMinLength, written.Length); + Assert.Equal(new byte[] { 0xFF }, written[^1..]); + } + + [Fact] + public async Task SetHttpOutputStreamAsync_NullPrelude_WritesNoPreludeBytes() + { + var rs = new ResponseStream(null); + var output = new MemoryStream(); + + await rs.SetHttpOutputStreamAsync(output); + + Assert.Empty(output.ToArray()); + } + + // ---- Prelude + delimiter single-chunk tests (via ChunkedStreamWriter) ---- + + [Fact] + public async Task SetHttpOutputStreamAsync_WithPrelude_ViaChunkedWriter_ProducesSingleChunk() + { + var preludeJson = Encoding.UTF8.GetBytes("{\"statusCode\":200}"); + var rs = new ResponseStream(preludeJson); + var rawOutput = new MemoryStream(); + var chunkedWriter = new ChunkedStreamWriter(rawOutput); + + await rs.SetHttpOutputStreamAsync(chunkedWriter); + + var wireBytes = Encoding.ASCII.GetString(rawOutput.ToArray()); + + // The prelude (18 bytes) + delimiter (8 bytes) = 26 bytes = 0x1A + // Should be exactly one chunk: "1A\r\n{prelude}{8 null bytes}\r\n" + var expectedDataLength = preludeJson.Length + 8; // 26 + var expectedHex = expectedDataLength.ToString("X"); + Assert.StartsWith($"{expectedHex}\r\n", wireBytes); + + // Verify there is only one chunk header (only one hex size prefix) + var chunkCount = 0; + var remaining = wireBytes; + while (remaining.Length > 0) + { + var crlfIndex = remaining.IndexOf("\r\n", StringComparison.Ordinal); + if (crlfIndex < 0) break; + var sizeStr = remaining.Substring(0, crlfIndex); + if (int.TryParse(sizeStr, System.Globalization.NumberStyles.HexNumber, null, out var chunkSize) && chunkSize >= 0) + { + chunkCount++; + // Skip past: hex\r\n{data}\r\n + remaining = remaining.Substring(crlfIndex + 2 + chunkSize + 2); + } + else + { + break; + } + } + Assert.Equal(1, chunkCount); + } + + [Fact] + public async Task SetHttpOutputStreamAsync_WithPrelude_ViaChunkedWriter_DelimiterImmediatelyFollowsPrelude() + { + var preludeJson = Encoding.UTF8.GetBytes("{\"statusCode\":201}"); + var rs = new ResponseStream(preludeJson); + var rawOutput = new MemoryStream(); + var chunkedWriter = new ChunkedStreamWriter(rawOutput); + + await rs.SetHttpOutputStreamAsync(chunkedWriter); + + // Parse the chunk to get the raw data payload + var wireBytes = rawOutput.ToArray(); + var wireStr = Encoding.ASCII.GetString(wireBytes); + var firstCrlf = wireStr.IndexOf("\r\n", StringComparison.Ordinal); + var dataStart = firstCrlf + 2; + var dataLength = preludeJson.Length + 8; + var chunkData = new byte[dataLength]; + Array.Copy(wireBytes, dataStart, chunkData, 0, dataLength); + + // First part should be the prelude JSON + Assert.Equal(preludeJson, chunkData[..preludeJson.Length]); + // Immediately followed by 8 null bytes (delimiter) + Assert.Equal(new byte[8], chunkData[preludeJson.Length..]); + } + + [Fact] + public async Task SetHttpOutputStreamAsync_WithPrelude_ViaChunkedWriter_HandlerDataInSeparateChunk() + { + var preludeJson = Encoding.UTF8.GetBytes("{\"statusCode\":200}"); + var rs = new ResponseStream(preludeJson); + var rawOutput = new MemoryStream(); + var chunkedWriter = new ChunkedStreamWriter(rawOutput); + + await rs.SetHttpOutputStreamAsync(chunkedWriter); + await rs.WriteAsync(Encoding.UTF8.GetBytes("body data"), 0, 9); + + var wireStr = Encoding.ASCII.GetString(rawOutput.ToArray()); + + // Should have exactly 2 chunks: one for prelude+delimiter, one for body + var chunkCount = 0; + var remaining = wireStr; + while (remaining.Length > 0) + { + var crlfIndex = remaining.IndexOf("\r\n", StringComparison.Ordinal); + if (crlfIndex < 0) break; + var sizeStr = remaining.Substring(0, crlfIndex); + if (int.TryParse(sizeStr, System.Globalization.NumberStyles.HexNumber, null, out var chunkSize) && chunkSize >= 0) + { + chunkCount++; + remaining = remaining.Substring(crlfIndex + 2 + chunkSize + 2); + } + else + { + break; + } + } + Assert.Equal(2, chunkCount); + } + + // ---- MarkCompleted idempotency ---- + + [Fact] + public async Task MarkCompleted_CalledTwice_DoesNotThrowOrDoubleRelease() + { + var (stream, _) = await CreateWiredStream(); + + stream.MarkCompleted(); + // Second call should be a no-op — semaphore should not be double-released + stream.MarkCompleted(); + + // WaitForCompletionAsync should complete exactly once without hanging + var waitTask = stream.WaitForCompletionAsync(); + var completed = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(2))); + Assert.Same(waitTask, completed); + } + + [Fact] + public async Task ReportError_ThenMarkCompleted_MarkCompletedIsNoOp() + { + var stream = new ResponseStream(Array.Empty()); + stream.ReportError(new Exception("error")); + + // MarkCompleted after ReportError should not throw and not double-release + stream.MarkCompleted(); + + // WaitForCompletionAsync should complete (released by ReportError) + var waitTask = stream.WaitForCompletionAsync(); + var completed = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(2))); + Assert.Same(waitTask, completed); + } + + // ---- BytesWritten tracking ---- + + [Fact] + public async Task BytesWritten_TracksAcrossMultipleWrites() + { + var (stream, _) = await CreateWiredStream(); + + await stream.WriteAsync(new byte[10], 0, 10); + await stream.WriteAsync(new byte[5], 0, 5); + + Assert.Equal(15, stream.BytesWritten); + } + + [Fact] + public async Task BytesWritten_ReflectsOffsetAndCount() + { + var (stream, _) = await CreateWiredStream(); + + await stream.WriteAsync(new byte[10], 2, 6); // only 6 bytes + + Assert.Equal(6, stream.BytesWritten); + } + } +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/RuntimeApiClientTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/RuntimeApiClientTests.cs new file mode 100644 index 000000000..71102ddf1 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/RuntimeApiClientTests.cs @@ -0,0 +1,211 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +using Xunit; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests +{ + /// + /// Tests for RuntimeApiClient streaming and buffered behavior. + /// Validates Properties 7, 8, 10, 13, 18. + /// + public class RuntimeApiClientTests + { + private const long MaxResponseSize = 20 * 1024 * 1024; + + /// + /// Mock HttpMessageHandler that captures the request for header inspection. + /// It completes the ResponseStream and returns immediately without reading + /// the content body, avoiding the SerializeToStreamAsync blocking issue. + /// + private class MockHttpMessageHandler : HttpMessageHandler + { + public HttpRequestMessage CapturedRequest { get; private set; } + private readonly ResponseStream _responseStream; + + public MockHttpMessageHandler(ResponseStream responseStream) + { + _responseStream = responseStream; + } + + protected override Task SendAsync( + HttpRequestMessage request, CancellationToken cancellationToken) + { + CapturedRequest = request; + + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK)); + } + } + + private static RuntimeApiClient CreateClientWithMockHandler( + ResponseStream stream, out MockHttpMessageHandler handler) + { + handler = new MockHttpMessageHandler(stream); + var httpClient = new HttpClient(handler); + var envVars = new TestEnvironmentVariables(); + envVars.SetEnvironmentVariable("AWS_LAMBDA_RUNTIME_API", "localhost:9001"); + return new RuntimeApiClient(envVars, httpClient); + } + + // --- Property 7: Streaming Response Mode Header --- + // Note: Properties 7, 8, 13 test the HttpClient-based streaming path which is only used on pre-NET8 targets. + // On NET8+, StartStreamingResponseAsync uses RawStreamingHttpClient (raw TCP) which doesn't go through HttpClient. + +#if !NET8_0_OR_GREATER + /// + /// Property 7: Streaming Response Mode Header + /// For any streaming response, the HTTP request should include + /// "Lambda-Runtime-Function-Response-Mode: streaming". + /// **Validates: Requirements 4.1** + /// + [Fact] + public async Task StartStreamingResponseAsync_IncludesStreamingResponseModeHeader() + { + var stream = new ResponseStream(Array.Empty()); + var client = CreateClientWithMockHandler(stream, out var handler); + + await client.StartStreamingResponseAsync("req-1", stream, CancellationToken.None); + + Assert.NotNull(handler.CapturedRequest); + Assert.True(handler.CapturedRequest.Headers.Contains(StreamingConstants.ResponseModeHeader)); + var values = handler.CapturedRequest.Headers.GetValues(StreamingConstants.ResponseModeHeader).ToList(); + Assert.Single(values); + Assert.Equal(StreamingConstants.StreamingResponseMode, values[0]); + } + + // --- Property 8: Chunked Transfer Encoding Header --- + + /// + /// Property 8: Chunked Transfer Encoding Header + /// For any streaming response, the HTTP request should include + /// "Transfer-Encoding: chunked". + /// **Validates: Requirements 4.2** + /// + [Fact] + public async Task StartStreamingResponseAsync_IncludesChunkedTransferEncodingHeader() + { + var stream = new ResponseStream(Array.Empty()); + var client = CreateClientWithMockHandler(stream, out var handler); + + await client.StartStreamingResponseAsync("req-2", stream, CancellationToken.None); + + Assert.NotNull(handler.CapturedRequest); + Assert.True(handler.CapturedRequest.Headers.TransferEncodingChunked); + } + + // --- Property 13: Trailer Declaration Header --- + + /// + /// Property 13: Trailer Declaration Header + /// For any streaming response, the HTTP request should include a "Trailer" header + /// declaring the error trailer headers upfront (since we cannot know at request + /// start whether an error will occur). + /// **Validates: Requirements 5.4** + /// + [Fact] + public async Task StartStreamingResponseAsync_DeclaresTrailerHeaderUpfront() + { + var stream = new ResponseStream(Array.Empty()); + var client = CreateClientWithMockHandler(stream, out var handler); + + await client.StartStreamingResponseAsync("req-3", stream, CancellationToken.None); + + Assert.NotNull(handler.CapturedRequest); + Assert.True(handler.CapturedRequest.Headers.Contains("Trailer")); + var trailerValue = string.Join(", ", handler.CapturedRequest.Headers.GetValues("Trailer")); + Assert.Contains(StreamingConstants.ErrorTypeTrailer, trailerValue); + Assert.Contains(StreamingConstants.ErrorBodyTrailer, trailerValue); + } +#endif + + // --- Property 10: Buffered Responses Exclude Streaming Headers --- + + /// + /// Mock HttpMessageHandler that captures the request for buffered response header inspection. + /// Returns an Accepted (202) response since that's what the InternalRuntimeApiClient expects. + /// + private class BufferedMockHttpMessageHandler : HttpMessageHandler + { + public HttpRequestMessage CapturedRequest { get; private set; } + + protected override Task SendAsync( + HttpRequestMessage request, CancellationToken cancellationToken) + { + CapturedRequest = request; + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.Accepted)); + } + } + + /// + /// Property 10: Buffered Responses Exclude Streaming Headers + /// For any buffered response (where CreateStream was not called), the HTTP request + /// should not include "Lambda-Runtime-Function-Response-Mode" or + /// "Transfer-Encoding: chunked" or "Trailer" headers. + /// **Validates: Requirements 4.6** + /// + [Fact] + public async Task SendResponseAsync_BufferedResponse_ExcludesStreamingHeaders() + { + var bufferedHandler = new BufferedMockHttpMessageHandler(); + var httpClient = new HttpClient(bufferedHandler); + var envVars = new TestEnvironmentVariables(); + envVars.SetEnvironmentVariable("AWS_LAMBDA_RUNTIME_API", "localhost:9001"); + var client = new RuntimeApiClient(envVars, httpClient); + + var outputStream = new MemoryStream(new byte[] { 1, 2, 3 }); + await client.SendResponseAsync("req-buffered", outputStream, CancellationToken.None); + + Assert.NotNull(bufferedHandler.CapturedRequest); + // Buffered responses must not include streaming-specific headers + Assert.False(bufferedHandler.CapturedRequest.Headers.Contains(StreamingConstants.ResponseModeHeader), + "Buffered response should not include Lambda-Runtime-Function-Response-Mode header"); + Assert.NotEqual(true, bufferedHandler.CapturedRequest.Headers.TransferEncodingChunked); + Assert.False(bufferedHandler.CapturedRequest.Headers.Contains("Trailer"), + "Buffered response should not include Trailer header"); + } + + // --- Argument validation --- + +#if NET8_0_OR_GREATER + [Fact] + public async Task StartStreamingResponseAsync_NullRequestId_ThrowsArgumentNullException() + { + var stream = new ResponseStream(Array.Empty()); + var client = CreateClientWithMockHandler(stream, out _); + + await Assert.ThrowsAsync( + () => client.StartStreamingResponseAsync(null, stream, CancellationToken.None)); + } + + [Fact] + public async Task StartStreamingResponseAsync_NullResponseStream_ThrowsArgumentNullException() + { + var stream = new ResponseStream(Array.Empty()); + var client = CreateClientWithMockHandler(stream, out _); + + await Assert.ThrowsAsync( + () => client.StartStreamingResponseAsync("req-5", null, CancellationToken.None)); + } +#endif + } +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/StreamingE2EWithMoq.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/StreamingE2EWithMoq.cs new file mode 100644 index 000000000..f46c76f13 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/StreamingE2EWithMoq.cs @@ -0,0 +1,545 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +using Amazon.Lambda.RuntimeSupport.UnitTests.TestHelpers; +using Xunit; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests +{ + [CollectionDefinition("ResponseStreamFactory")] + public class ResponseStreamFactoryCollection { } + + /// + /// End-to-end integration tests for the true-streaming architecture. + /// These tests exercise the full pipeline: LambdaBootstrap → ResponseStreamFactory → + /// ResponseStream → captured HTTP output stream. + /// + [Collection("ResponseStreamFactory")] + public class StreamingE2EWithMoq : IDisposable + { + public void Dispose() + { + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: false); + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: true); + } + + // ─── Helpers ──────────────────────────────────────────────────────────────── + + private static Dictionary> MakeHeaders(string requestId = "test-request-id") + => new Dictionary> + { + { RuntimeApiHeaders.HeaderAwsRequestId, new List { requestId } }, + { RuntimeApiHeaders.HeaderInvokedFunctionArn, new List { "arn:aws:lambda:us-east-1:123456789012:function:test" } }, + { RuntimeApiHeaders.HeaderAwsTenantId, new List { "tenant-id" } }, + { RuntimeApiHeaders.HeaderTraceId, new List { "trace-id" } }, + { RuntimeApiHeaders.HeaderDeadlineMs, new List { "9999999999999" } }, + }; + + /// + /// A capturing RuntimeApiClient that records the raw bytes written to the HTTP output stream + /// by SerializeToStreamAsync. + /// + private class CapturingStreamingRuntimeApiClient : RuntimeApiClient, IRuntimeApiClient + { + private readonly IEnvironmentVariables _envVars; + private readonly Dictionary> _headers; + + public bool StartStreamingCalled { get; private set; } + public bool SendResponseCalled { get; private set; } + public bool ReportInvocationErrorCalled { get; private set; } + public byte[] CapturedHttpBytes { get; private set; } + public ResponseStream LastResponseStream { get; private set; } + public Stream LastBufferedOutputStream { get; private set; } + + public new Amazon.Lambda.RuntimeSupport.Helpers.IConsoleLoggerWriter ConsoleLogger { get; } = new Helpers.LogLevelLoggerWriter(new SystemEnvironmentVariables()); + + public CapturingStreamingRuntimeApiClient( + IEnvironmentVariables envVars, + Dictionary> headers) + : base(envVars, new NoOpInternalRuntimeApiClient()) + { + _envVars = envVars; + _headers = headers; + } + + public new async Task GetNextInvocationAsync(CancellationToken cancellationToken = default) + { + _headers[RuntimeApiHeaders.HeaderTraceId] = new List { Guid.NewGuid().ToString() }; + var inputStream = new MemoryStream(new byte[0]); + return new InvocationRequest + { + InputStream = inputStream, + LambdaContext = new LambdaContext( + new RuntimeApiHeaders(_headers), + new LambdaEnvironment(_envVars), + new TestDateTimeHelper(), + new Helpers.SimpleLoggerWriter(_envVars)) + }; + } + + internal override async Task StartStreamingResponseAsync( + string awsRequestId, ResponseStream responseStream, CancellationToken cancellationToken = default) + { + StartStreamingCalled = true; + LastResponseStream = responseStream; + + // Use a real MemoryStream as the HTTP output stream so we capture actual bytes + var captureStream = new MemoryStream(); + await responseStream.SetHttpOutputStreamAsync(captureStream, cancellationToken); + + // Wait for the handler to finish writing (mirrors real RawStreamingHttpClient behavior) + await responseStream.WaitForCompletionAsync(cancellationToken); + CapturedHttpBytes = captureStream.ToArray(); + return new NoOpDisposable(); + } + + public new async Task SendResponseAsync(string awsRequestId, Stream outputStream, CancellationToken cancellationToken = default) + { + SendResponseCalled = true; + if (outputStream != null) + { + var ms = new MemoryStream(); + await outputStream.CopyToAsync(ms); + ms.Position = 0; + LastBufferedOutputStream = ms; + } + } + + public new Task ReportInvocationErrorAsync(string awsRequestId, Exception exception, CancellationToken cancellationToken = default) + { + ReportInvocationErrorCalled = true; + return Task.CompletedTask; + } + + public new Task ReportInitializationErrorAsync(Exception exception, string errorType = null, CancellationToken cancellationToken = default) + => Task.CompletedTask; + + public new Task ReportInitializationErrorAsync(string errorType, CancellationToken cancellationToken = default) + => Task.CompletedTask; + +#if NET8_0_OR_GREATER + public new Task RestoreNextInvocationAsync(CancellationToken cancellationToken = default) => Task.CompletedTask; + public new Task ReportRestoreErrorAsync(Exception exception, string errorType = null, CancellationToken cancellationToken = default) => Task.CompletedTask; +#endif + } + + private static CapturingStreamingRuntimeApiClient CreateClient(string requestId = "test-request-id") + => new CapturingStreamingRuntimeApiClient(new TestEnvironmentVariables(), MakeHeaders(requestId)); + + /// + /// End-to-end: all data is transmitted correctly (content round-trip). + /// Requirements: 3.2, 4.3, 10.1 + /// + [Fact] + public async Task Streaming_AllDataTransmitted_ContentRoundTrip() + { + var client = CreateClient(); + var payload = Encoding.UTF8.GetBytes("integration test payload"); + + LambdaBootstrapHandler handler = async (invocation) => + { + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + await stream.WriteAsync(payload); + return new InvocationResponse(Stream.Null, false); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + var output = client.CapturedHttpBytes; + Assert.NotNull(output); + + var outputStr = Encoding.UTF8.GetString(output); + Assert.Contains("integration test payload", outputStr); + } + + /// + /// End-to-end: stream is finalized (final chunk written, BytesWritten matches). + /// Requirements: 3.2, 4.3, 10.1 + /// + [Fact] + public async Task Streaming_StreamFinalized_BytesWrittenMatchesPayload() + { + var client = CreateClient(); + var data = Encoding.UTF8.GetBytes("finalization check"); + + LambdaBootstrapHandler handler = async (invocation) => + { + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + await stream.WriteAsync(data); + return new InvocationResponse(Stream.Null, false); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + Assert.NotNull(client.LastResponseStream); + Assert.Equal(data.Length, client.LastResponseStream.BytesWritten); + } + + // ─── 10.2 End-to-end buffered response ────────────────────────────────────── + + /// + /// End-to-end: handler does NOT call CreateStream — response goes via buffered path. + /// Verifies SendResponseAsync is called and streaming headers are absent. + /// Requirements: 1.5, 4.6, 9.4 + /// + [Fact] + public async Task Buffered_HandlerDoesNotCallCreateStream_UsesSendResponsePath() + { + var client = CreateClient(); + var responseBody = Encoding.UTF8.GetBytes("buffered response body"); + + LambdaBootstrapHandler handler = async (invocation) => + { + await Task.Yield(); + return new InvocationResponse(new MemoryStream(responseBody)); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + Assert.False(client.StartStreamingCalled, "StartStreamingResponseAsync should NOT be called for buffered mode"); + Assert.True(client.SendResponseCalled, "SendResponseAsync should be called for buffered mode"); + Assert.Null(client.CapturedHttpBytes); + } + + /// + /// End-to-end: buffered response body is transmitted correctly. + /// Requirements: 1.5, 4.6, 9.4 + /// + [Fact] + public async Task Buffered_ResponseBodyTransmittedCorrectly() + { + var client = CreateClient(); + var responseBody = Encoding.UTF8.GetBytes("hello buffered world"); + + LambdaBootstrapHandler handler = async (invocation) => + { + await Task.Yield(); + return new InvocationResponse(new MemoryStream(responseBody)); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + Assert.True(client.SendResponseCalled); + Assert.NotNull(client.LastBufferedOutputStream); + var received = new MemoryStream(); + await client.LastBufferedOutputStream.CopyToAsync(received); + Assert.Equal(responseBody, received.ToArray()); + } + + /// + /// End-to-end: midstream error sets error state on ResponseStream with exception details. + /// In production, RawStreamingHttpClient reads this state and writes trailing headers. + /// Requirements: 5.2, 5.3 + /// + [Fact] + public async Task MidstreamError_SetsErrorStateWithExceptionDetails() + { + var client = CreateClient(); + const string errorMessage = "something went wrong mid-stream"; + + LambdaBootstrapHandler handler = async (invocation) => + { + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + await stream.WriteAsync(Encoding.UTF8.GetBytes("some data")); + throw new InvalidOperationException(errorMessage); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + Assert.True(client.StartStreamingCalled); + Assert.NotNull(client.LastResponseStream); + Assert.True(client.LastResponseStream.HasError); + Assert.NotNull(client.LastResponseStream.ReportedError); + Assert.IsType(client.LastResponseStream.ReportedError); + Assert.Equal(errorMessage, client.LastResponseStream.ReportedError.Message); + + // Verify the handler's data was still captured before the error + var output = Encoding.UTF8.GetString(client.CapturedHttpBytes); + Assert.Contains("some data", output); + } + + // ─── 10.4 Multi-concurrency ────────────────────────────────────────────────── + + /// + /// Multi-concurrency: concurrent invocations use AsyncLocal for state isolation. + /// Each invocation independently uses streaming or buffered mode without interference. + /// Requirements: 2.9, 6.5, 8.9 + /// + [Fact] + public async Task MultiConcurrency_ConcurrentInvocations_StateIsolated() + { + const int concurrency = 3; + var results = new ConcurrentDictionary(); + var barrier = new SemaphoreSlim(0, concurrency); + var allStarted = new SemaphoreSlim(0, concurrency); + + // Simulate concurrent invocations using AsyncLocal directly + var tasks = new List(); + for (int i = 0; i < concurrency; i++) + { + var requestId = $"req-{i}"; + var payload = $"payload-{i}"; + tasks.Add(Task.Run(async () => + { + var mockClient = new MockMultiConcurrencyStreamingClient(); + ResponseStreamFactory.InitializeInvocation( + requestId, + isMultiConcurrency: true, + mockClient, + CancellationToken.None); + + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + allStarted.Release(); + + // Wait until all tasks have started (to ensure true concurrency) + await barrier.WaitAsync(); + + await stream.WriteAsync(Encoding.UTF8.GetBytes(payload)); + stream.MarkCompleted(); + + // Verify this invocation's stream is still accessible + var retrieved = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: true); + results[requestId] = retrieved != null ? payload : "MISSING"; + + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: true); + })); + } + + // Wait for all tasks to start, then release the barrier + for (int i = 0; i < concurrency; i++) + await allStarted.WaitAsync(); + barrier.Release(concurrency); + + await Task.WhenAll(tasks); + + // Each invocation should have seen its own stream + Assert.Equal(concurrency, results.Count); + for (int i = 0; i < concurrency; i++) + Assert.Equal($"payload-{i}", results[$"req-{i}"]); + } + + /// + /// Multi-concurrency: streaming and buffered invocations can run concurrently without interference. + /// Requirements: 2.9, 6.5, 8.9 + /// + [Fact] + public async Task MultiConcurrency_StreamingAndBufferedMixedConcurrently_NoInterference() + { + var streamingResults = new ConcurrentBag(); + var bufferedResults = new ConcurrentBag(); + var barrier = new SemaphoreSlim(0, 4); + var allStarted = new SemaphoreSlim(0, 4); + + var tasks = new List(); + + // 2 streaming invocations + for (int i = 0; i < 2; i++) + { + var requestId = $"stream-{i}"; + tasks.Add(Task.Run(async () => + { + var mockClient = new MockMultiConcurrencyStreamingClient(); + ResponseStreamFactory.InitializeInvocation( + requestId, + isMultiConcurrency: true, mockClient, CancellationToken.None); + + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + allStarted.Release(); + await barrier.WaitAsync(); + + await stream.WriteAsync(Encoding.UTF8.GetBytes("streaming data")); + stream.MarkCompleted(); + + var retrieved = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: true); + streamingResults.Add(retrieved != null); + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: true); + })); + } + + // 2 buffered invocations (no CreateStream) + for (int i = 0; i < 2; i++) + { + var requestId = $"buffered-{i}"; + tasks.Add(Task.Run(async () => + { + var mockClient = new MockMultiConcurrencyStreamingClient(); + ResponseStreamFactory.InitializeInvocation( + requestId, + isMultiConcurrency: true, mockClient, CancellationToken.None); + + allStarted.Release(); + await barrier.WaitAsync(); + + // No CreateStream — buffered mode + var retrieved = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: true); + bufferedResults.Add(retrieved == null); // should be null (no stream created) + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: true); + })); + } + + for (int i = 0; i < 4; i++) + await allStarted.WaitAsync(); + barrier.Release(4); + + await Task.WhenAll(tasks); + + Assert.Equal(2, streamingResults.Count); + Assert.All(streamingResults, r => Assert.True(r, "Streaming invocation should have a stream")); + + Assert.Equal(2, bufferedResults.Count); + Assert.All(bufferedResults, r => Assert.True(r, "Buffered invocation should have no stream")); + } + + /// + /// Minimal mock RuntimeApiClient for multi-concurrency tests. + /// Accepts StartStreamingResponseAsync calls without real HTTP. + /// + private class MockMultiConcurrencyStreamingClient : RuntimeApiClient + { + public MockMultiConcurrencyStreamingClient() + : base(new TestEnvironmentVariables(), new NoOpInternalRuntimeApiClient()) { } + + internal override async Task StartStreamingResponseAsync( + string awsRequestId, ResponseStream responseStream, CancellationToken cancellationToken = default) + { + // Provide the HTTP output stream so writes don't block + await responseStream.SetHttpOutputStreamAsync(new MemoryStream()); + await responseStream.WaitForCompletionAsync(); + return new NoOpDisposable(); + } + } + + // ─── 10.5 Backward compatibility ──────────────────────────────────────────── + + /// + /// Backward compatibility: existing handler signatures (event + ILambdaContext) work without modification. + /// Requirements: 9.1, 9.2, 9.3 + /// + [Fact] + public async Task BackwardCompat_ExistingHandlerSignature_WorksUnchanged() + { + var client = CreateClient(); + bool handlerCalled = false; + + // Simulate a classic handler that returns a buffered response + LambdaBootstrapHandler handler = async (invocation) => + { + handlerCalled = true; + await Task.Yield(); + return new InvocationResponse(new MemoryStream(Encoding.UTF8.GetBytes("classic response"))); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + Assert.True(handlerCalled); + Assert.True(client.SendResponseCalled); + Assert.False(client.StartStreamingCalled); + } + + /// + /// Backward compatibility: no regression in buffered response behavior — response body is correct. + /// Requirements: 9.4, 9.5 + /// + [Fact] + public async Task BackwardCompat_BufferedResponse_NoRegression() + { + var client = CreateClient(); + var expected = Encoding.UTF8.GetBytes("no regression here"); + + LambdaBootstrapHandler handler = async (invocation) => + { + await Task.Yield(); + return new InvocationResponse(new MemoryStream(expected)); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + Assert.True(client.SendResponseCalled); + Assert.NotNull(client.LastBufferedOutputStream); + var received = new MemoryStream(); + await client.LastBufferedOutputStream.CopyToAsync(received); + Assert.Equal(expected, received.ToArray()); + } + + /// + /// Backward compatibility: handler that returns null OutputStream still works. + /// Requirements: 9.4 + /// + [Fact] + public async Task BackwardCompat_NullOutputStream_HandledGracefully() + { + var client = CreateClient(); + + LambdaBootstrapHandler handler = async (invocation) => + { + await Task.Yield(); + return new InvocationResponse(Stream.Null, false); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + + // Should not throw + await bootstrap.InvokeOnceAsync(); + + Assert.True(client.SendResponseCalled); + } + + /// + /// Backward compatibility: handler that throws before CreateStream uses standard error path. + /// Requirements: 9.5 + /// + [Fact] + public async Task BackwardCompat_HandlerThrows_StandardErrorReportingUsed() + { + var client = CreateClient(); + + LambdaBootstrapHandler handler = async (invocation) => + { + await Task.Yield(); + throw new Exception("classic handler error"); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + Assert.True(client.ReportInvocationErrorCalled); + Assert.False(client.StartStreamingCalled); + } + } +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/NoOpInternalRuntimeApiClient.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/NoOpInternalRuntimeApiClient.cs new file mode 100644 index 000000000..9fa0434cd --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/NoOpInternalRuntimeApiClient.cs @@ -0,0 +1,60 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests.TestHelpers +{ + /// + /// A no-op implementation of IInternalRuntimeApiClient for unit tests + /// that need to construct a RuntimeApiClient without real HTTP calls. + /// + internal class NoOpInternalRuntimeApiClient : IInternalRuntimeApiClient + { + private static readonly SwaggerResponse EmptyStatusResponse = + new SwaggerResponse(200, new Dictionary>(), new StatusResponse()); + + public Task> ErrorAsync( + string lambda_Runtime_Function_Error_Type, string errorJson, CancellationToken cancellationToken) + => Task.FromResult(EmptyStatusResponse); + + public Task> NextAsync(CancellationToken cancellationToken) + => Task.FromResult(new SwaggerResponse(200, new Dictionary>(), Stream.Null)); + + public Task> ResponseAsync(string awsRequestId, Stream outputStream) + => Task.FromResult(EmptyStatusResponse); + + public Task> ResponseAsync( + string awsRequestId, Stream outputStream, CancellationToken cancellationToken) + => Task.FromResult(EmptyStatusResponse); + + public Task> ErrorWithXRayCauseAsync( + string awsRequestId, string lambda_Runtime_Function_Error_Type, + string errorJson, string xrayCause, CancellationToken cancellationToken) + => Task.FromResult(EmptyStatusResponse); + +#if NET8_0_OR_GREATER + public Task> RestoreNextAsync(CancellationToken cancellationToken) + => Task.FromResult(new SwaggerResponse(200, new Dictionary>(), Stream.Null)); + + public Task> RestoreErrorAsync( + string lambda_Runtime_Function_Error_Type, string errorJson, CancellationToken cancellationToken) + => Task.FromResult(EmptyStatusResponse); +#endif + } +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/TestStreamingRuntimeApiClient.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/TestStreamingRuntimeApiClient.cs new file mode 100644 index 000000000..1cd6fa09e --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/TestStreamingRuntimeApiClient.cs @@ -0,0 +1,142 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +using Amazon.Lambda.RuntimeSupport.Helpers; +using Amazon.Lambda.RuntimeSupport.UnitTests.TestHelpers; +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests +{ + /// + /// A RuntimeApiClient subclass for testing LambdaBootstrap streaming integration. + /// Extends RuntimeApiClient so the (RuntimeApiClient)Client cast in LambdaBootstrap works. + /// Overrides StartStreamingResponseAsync to avoid real HTTP calls. + /// + internal class TestStreamingRuntimeApiClient : RuntimeApiClient, IRuntimeApiClient + { + private readonly IEnvironmentVariables _environmentVariables; + private readonly Dictionary> _headers; + + public new IConsoleLoggerWriter ConsoleLogger { get; } = new LogLevelLoggerWriter(new SystemEnvironmentVariables()); + + public TestStreamingRuntimeApiClient(IEnvironmentVariables environmentVariables, Dictionary> headers) + : base(environmentVariables, new NoOpInternalRuntimeApiClient()) + { + _environmentVariables = environmentVariables; + _headers = headers; + } + + // Tracking flags + public bool GetNextInvocationAsyncCalled { get; private set; } + public bool ReportInitializationErrorAsyncExceptionCalled { get; private set; } + public bool ReportInvocationErrorAsyncExceptionCalled { get; private set; } + public bool SendResponseAsyncCalled { get; private set; } + public bool StartStreamingResponseAsyncCalled { get; private set; } + + public string LastTraceId { get; private set; } + public byte[] FunctionInput { get; set; } + public Stream LastOutputStream { get; private set; } + public Exception LastRecordedException { get; private set; } + public ResponseStream LastStreamingResponseStream { get; private set; } + + public new async Task GetNextInvocationAsync(CancellationToken cancellationToken = default) + { + GetNextInvocationAsyncCalled = true; + + LastTraceId = Guid.NewGuid().ToString(); + _headers[RuntimeApiHeaders.HeaderTraceId] = new List() { LastTraceId }; + + var inputStream = new MemoryStream(FunctionInput == null ? new byte[0] : FunctionInput); + inputStream.Position = 0; + + return new InvocationRequest() + { + InputStream = inputStream, + LambdaContext = new LambdaContext( + new RuntimeApiHeaders(_headers), + new LambdaEnvironment(_environmentVariables), + new TestDateTimeHelper(), new SimpleLoggerWriter(_environmentVariables)) + }; + } + + public new Task ReportInitializationErrorAsync(Exception exception, String errorType = null, CancellationToken cancellationToken = default) + { + LastRecordedException = exception; + ReportInitializationErrorAsyncExceptionCalled = true; + return Task.CompletedTask; + } + + public new Task ReportInitializationErrorAsync(string errorType, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + public new Task ReportInvocationErrorAsync(string awsRequestId, Exception exception, CancellationToken cancellationToken = default) + { + LastRecordedException = exception; + ReportInvocationErrorAsyncExceptionCalled = true; + return Task.CompletedTask; + } + + public new async Task SendResponseAsync(string awsRequestId, Stream outputStream, CancellationToken cancellationToken = default) + { + if (outputStream != null) + { + LastOutputStream = new MemoryStream((int)outputStream.Length); + outputStream.CopyTo(LastOutputStream); + LastOutputStream.Position = 0; + } + + SendResponseAsyncCalled = true; + } + + internal override async Task StartStreamingResponseAsync( + string awsRequestId, ResponseStream responseStream, CancellationToken cancellationToken = default) + { + StartStreamingResponseAsyncCalled = true; + LastStreamingResponseStream = responseStream; + + // Simulate the HTTP stream being available + await responseStream.SetHttpOutputStreamAsync(new MemoryStream(), cancellationToken); + + // Wait for the handler to finish writing (mirrors real SerializeToStreamAsync behavior) + await responseStream.WaitForCompletionAsync(); + + return new NoOpDisposable(); + } + +#if NET8_0_OR_GREATER + public new Task RestoreNextInvocationAsync(CancellationToken cancellationToken = default) + => Task.CompletedTask; + + public new Task ReportRestoreErrorAsync(Exception exception, String errorType = null, CancellationToken cancellationToken = default) + => Task.CompletedTask; +#endif + } + + /// + /// A no-op IDisposable for test overrides of StartStreamingResponseAsync. + /// + internal class NoOpDisposable : IDisposable + { + public void Dispose() { } + } +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/Function.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/Function.cs new file mode 100644 index 000000000..8c645ff5b --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/Function.cs @@ -0,0 +1,56 @@ +#pragma warning disable CA2252 + +using Amazon.Lambda.Core; +using Amazon.Lambda.Core.ResponseStreaming; +using Amazon.Lambda.RuntimeSupport; +using Amazon.Lambda.Serialization.SystemTextJson; + +// The function handler that will be called for each Lambda event +var handler = async (string input, ILambdaContext context) => +{ + using var stream = LambdaResponseStreamFactory.CreateStream(); + + switch(input) + { + case $"{nameof(SimpleFunctionHandler)}": + await SimpleFunctionHandler(stream, context); + break; + case $"{nameof(StreamContentHandler)}": + await StreamContentHandler(stream, context); + break; + case $"{nameof(UnhandledExceptionHandler)}": + await UnhandledExceptionHandler(stream, context); + break; + default: + throw new ArgumentException($"Unknown handler scenario {input}"); + } +}; + +async Task SimpleFunctionHandler(Stream stream, ILambdaContext context) +{ + using var writer = new StreamWriter(stream); + await writer.WriteAsync("Hello, World!"); +} + +async Task StreamContentHandler(Stream stream, ILambdaContext context) +{ + using var writer = new StreamWriter(stream); + + await writer.WriteLineAsync("Starting stream content..."); + for(var i = 0; i < 10000; i++) + { + await writer.WriteLineAsync($"Line {i}"); + } + await writer.WriteLineAsync("Finish stream content"); +} + +async Task UnhandledExceptionHandler(Stream stream, ILambdaContext context) +{ + using var writer = new StreamWriter(stream); + await writer.WriteAsync("This method will fail"); + throw new InvalidOperationException("This is an unhandled exception"); +} + +await LambdaBootstrapBuilder.Create(handler, new DefaultLambdaJsonSerializer()) + .Build() + .RunAsync(); diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/ResponseStreamingFunctionHandlers.csproj b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/ResponseStreamingFunctionHandlers.csproj new file mode 100644 index 000000000..fa81eaa17 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/ResponseStreamingFunctionHandlers.csproj @@ -0,0 +1,19 @@ + + + Exe + net10.0 + enable + enable + true + Lambda + + true + + true + + + + + + + \ No newline at end of file diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/aws-lambda-tools-defaults.json b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/aws-lambda-tools-defaults.json new file mode 100644 index 000000000..3042c3978 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/aws-lambda-tools-defaults.json @@ -0,0 +1,15 @@ +{ + "Information": [ + "This file provides default values for the deployment wizard inside Visual Studio and the AWS Lambda commands added to the .NET Core CLI.", + "To learn more about the Lambda commands with the .NET Core CLI execute the following command at the command line in the project root directory.", + "dotnet lambda help", + "All the command line options for the Lambda command can be specified in this file." + ], + "profile": "default", + "region": "us-west-2", + "configuration": "Release", + "function-runtime": "dotnet10", + "function-memory-size": 512, + "function-timeout": 30, + "function-handler": "ResponseStreamingFunctionHandlers" +} \ No newline at end of file diff --git a/Libraries/test/IntegrationTests.Helpers/LambdaHelper.cs b/Libraries/test/IntegrationTests.Helpers/LambdaHelper.cs index 6436c7c7b..591eb5e06 100644 --- a/Libraries/test/IntegrationTests.Helpers/LambdaHelper.cs +++ b/Libraries/test/IntegrationTests.Helpers/LambdaHelper.cs @@ -1,3 +1,6 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + using System.Collections.Generic; using System.Threading.Tasks; using Amazon.Lambda; @@ -55,6 +58,14 @@ public async Task ListEventSourceMappingsAsync( }); } + public async Task GetFunctionUrlConfigAsync(string functionName) + { + return await _lambdaClient.GetFunctionUrlConfigAsync(new GetFunctionUrlConfigRequest + { + FunctionName = functionName + }); + } + public async Task WaitTillNotPending(List functions) { foreach (var function in functions) diff --git a/Libraries/test/IntegrationTests.Helpers/S3Helper.cs b/Libraries/test/IntegrationTests.Helpers/S3Helper.cs index e3c09eefd..9620732b7 100644 --- a/Libraries/test/IntegrationTests.Helpers/S3Helper.cs +++ b/Libraries/test/IntegrationTests.Helpers/S3Helper.cs @@ -33,5 +33,13 @@ public async Task BucketExistsAsync(string bucketName) var response = await _s3Client.ListBucketsAsync(new ListBucketsRequest()); return response.Buckets.Any(x => x.BucketName.Equals(bucketName)); } + + public async Task GetBucketNotificationAsync(string bucketName) + { + return await _s3Client.GetBucketNotificationAsync(new GetBucketNotificationRequest + { + BucketName = bucketName + }); + } } } diff --git a/Libraries/test/TestCustomAuthorizerApp/serverless.template b/Libraries/test/TestCustomAuthorizerApp/serverless.template index 41e03a2f9..d50b5f6ef 100644 --- a/Libraries/test/TestCustomAuthorizerApp/serverless.template +++ b/Libraries/test/TestCustomAuthorizerApp/serverless.template @@ -1,7 +1,7 @@ { "AWSTemplateFormatVersion": "2010-09-09", "Transform": "AWS::Serverless-2016-10-31", - "Description": "This template is partially managed by Amazon.Lambda.Annotations (v1.10.0.0).", + "Description": "This template is partially managed by Amazon.Lambda.Annotations (v1.13.0.0).", "Resources": { "AnnotationsHttpApi": { "Type": "AWS::Serverless::HttpApi", diff --git a/Libraries/test/TestExecutableServerlessApp/serverless.template b/Libraries/test/TestExecutableServerlessApp/serverless.template index 3092da266..a4112a9eb 100644 --- a/Libraries/test/TestExecutableServerlessApp/serverless.template +++ b/Libraries/test/TestExecutableServerlessApp/serverless.template @@ -1,7 +1,7 @@ { "AWSTemplateFormatVersion": "2010-09-09", "Transform": "AWS::Serverless-2016-10-31", - "Description": "An AWS Serverless Application. This template is partially managed by Amazon.Lambda.Annotations (v1.10.0.0).", + "Description": "An AWS Serverless Application. This template is partially managed by Amazon.Lambda.Annotations (v1.13.0.0).", "Parameters": { "ArchitectureTypeParameter": { "Type": "String", diff --git a/Libraries/test/TestServerlessApp.ALB.IntegrationTests/ALBIntegrationTestContextFixture.cs b/Libraries/test/TestServerlessApp.ALB.IntegrationTests/ALBIntegrationTestContextFixture.cs new file mode 100644 index 000000000..40c70f7be --- /dev/null +++ b/Libraries/test/TestServerlessApp.ALB.IntegrationTests/ALBIntegrationTestContextFixture.cs @@ -0,0 +1,172 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Threading.Tasks; +using Amazon.CloudFormation; +using Amazon.ElasticLoadBalancingV2; +using Amazon.ElasticLoadBalancingV2.Model; +using Amazon.Lambda; +using Amazon.S3; +using IntegrationTests.Helpers; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using Xunit; + +namespace TestServerlessApp.ALB.IntegrationTests +{ + public class ALBIntegrationTestContextFixture : IAsyncLifetime + { + private readonly CloudFormationHelper _cloudFormationHelper; + private readonly S3Helper _s3Helper; + + private string _stackName; + private string _bucketName; + + public readonly AmazonElasticLoadBalancingV2Client ELBv2Client; + public readonly LambdaHelper LambdaHelper; + public readonly HttpClient HttpClient; + + public string ALBDnsName; + public string LoadBalancerArn; + + public ALBIntegrationTestContextFixture() + { + _cloudFormationHelper = new CloudFormationHelper(new AmazonCloudFormationClient(Amazon.RegionEndpoint.USWest2)); + _s3Helper = new S3Helper(new AmazonS3Client(Amazon.RegionEndpoint.USWest2)); + LambdaHelper = new LambdaHelper(new AmazonLambdaClient(Amazon.RegionEndpoint.USWest2)); + ELBv2Client = new AmazonElasticLoadBalancingV2Client(Amazon.RegionEndpoint.USWest2); + HttpClient = new HttpClient(); + } + + public async Task InitializeAsync() + { + var scriptPath = Path.Combine("..", "..", "..", "DeploymentScript.ps1"); + Console.WriteLine($"[ALB IntegrationTest] Running deployment script: {scriptPath}"); + await CommandLineWrapper.RunAsync($"pwsh {scriptPath}"); + Console.WriteLine("[ALB IntegrationTest] Deployment script completed successfully."); + + _stackName = GetConfigValue("stack-name"); + _bucketName = GetConfigValue("s3-bucket"); + Console.WriteLine($"[ALB IntegrationTest] Stack name: '{_stackName}', Bucket name: '{_bucketName}'"); + Assert.False(string.IsNullOrEmpty(_stackName), "Stack name should not be empty"); + Assert.False(string.IsNullOrEmpty(_bucketName), "Bucket name should not be empty"); + + // Check stack status + var stackStatus = await _cloudFormationHelper.GetStackStatusAsync(_stackName); + Console.WriteLine($"[ALB IntegrationTest] Stack status: {stackStatus}"); + Assert.NotNull(stackStatus); + Assert.Equal(StackStatus.CREATE_COMPLETE, stackStatus); + + // Get ALB DNS name from stack outputs + ALBDnsName = await _cloudFormationHelper.GetOutputValueAsync(_stackName, "ALBDnsName"); + Console.WriteLine($"[ALB IntegrationTest] ALB DNS Name: {ALBDnsName}"); + Assert.False(string.IsNullOrEmpty(ALBDnsName), "ALB DNS Name should not be empty"); + + // Resolve the LoadBalancerArn from DNS name for scoped queries + var lbResponse = await ELBv2Client.DescribeLoadBalancersAsync(new DescribeLoadBalancersRequest()); + var loadBalancer = lbResponse.LoadBalancers.FirstOrDefault(lb => lb.DNSName == ALBDnsName); + if (loadBalancer != null) + { + LoadBalancerArn = loadBalancer.LoadBalancerArn; + Console.WriteLine($"[ALB IntegrationTest] LoadBalancer ARN: {LoadBalancerArn}"); + } + + // Wait for Lambda targets to become healthy by polling target health + Console.WriteLine("[ALB IntegrationTest] Waiting for targets to become healthy..."); + await WaitForTargetsHealthy(timeoutSeconds: 120, pollIntervalSeconds: 10); + } + + /// + /// Polls ALB target group health until at least one target is healthy or the timeout is reached. + /// + private async Task WaitForTargetsHealthy(int timeoutSeconds, int pollIntervalSeconds) + { + var deadline = DateTime.UtcNow.AddSeconds(timeoutSeconds); + + while (DateTime.UtcNow < deadline) + { + try + { + if (!string.IsNullOrEmpty(LoadBalancerArn)) + { + var tgResponse = await ELBv2Client.DescribeTargetGroupsAsync(new DescribeTargetGroupsRequest + { + LoadBalancerArn = LoadBalancerArn + }); + + var lambdaTgs = tgResponse.TargetGroups.Where(tg => tg.TargetType == TargetTypeEnum.Lambda).ToList(); + if (lambdaTgs.Count >= 2) + { + var allHealthy = true; + foreach (var tg in lambdaTgs) + { + var healthResponse = await ELBv2Client.DescribeTargetHealthAsync(new DescribeTargetHealthRequest + { + TargetGroupArn = tg.TargetGroupArn + }); + if (!healthResponse.TargetHealthDescriptions.Any(t => t.TargetHealth.State == TargetHealthStateEnum.Healthy)) + { + allHealthy = false; + break; + } + } + + if (allHealthy) + { + Console.WriteLine("[ALB IntegrationTest] All targets are healthy."); + return; + } + } + } + } + catch (Exception ex) + { + Console.WriteLine($"[ALB IntegrationTest] Polling error (will retry): {ex.Message}"); + } + + Console.WriteLine($"[ALB IntegrationTest] Targets not yet healthy, retrying in {pollIntervalSeconds}s..."); + await Task.Delay(pollIntervalSeconds * 1000); + } + + Console.WriteLine("[ALB IntegrationTest] Warning: Timed out waiting for targets to become healthy. Proceeding anyway."); + } + + public async Task DisposeAsync() + { + if (!string.IsNullOrEmpty(_stackName)) + { + Console.WriteLine($"[ALB IntegrationTest] Cleaning up stack '{_stackName}'..."); + await _cloudFormationHelper.DeleteStackAsync(_stackName); + Assert.True(await _cloudFormationHelper.IsDeletedAsync(_stackName), + $"The stack '{_stackName}' still exists and will have to be manually deleted."); + } + + if (!string.IsNullOrEmpty(_bucketName)) + { + Console.WriteLine($"[ALB IntegrationTest] Cleaning up bucket '{_bucketName}'..."); + await _s3Helper.DeleteBucketAsync(_bucketName); + Assert.False(await _s3Helper.BucketExistsAsync(_bucketName), + $"The bucket '{_bucketName}' still exists and will have to be manually deleted."); + } + + // Reset aws-lambda-tools-defaults.json to original values + var filePath = Path.Combine("..", "..", "..", "..", "TestServerlessApp.ALB", "aws-lambda-tools-defaults.json"); + var token = JObject.Parse(await File.ReadAllTextAsync(filePath)); + token["s3-bucket"] = "test-serverless-app-alb"; + token["stack-name"] = "test-serverless-app-alb"; + token["function-architecture"] = "x86_64"; + await File.WriteAllTextAsync(filePath, token.ToString(Formatting.Indented)); + } + + private string GetConfigValue(string key) + { + var filePath = Path.Combine("..", "..", "..", "..", "TestServerlessApp.ALB", "aws-lambda-tools-defaults.json"); + var token = JObject.Parse(File.ReadAllText(filePath))[key]; + return token?.ToObject(); + } + } +} diff --git a/Libraries/test/TestServerlessApp.ALB.IntegrationTests/ALBIntegrationTestContextFixtureCollection.cs b/Libraries/test/TestServerlessApp.ALB.IntegrationTests/ALBIntegrationTestContextFixtureCollection.cs new file mode 100644 index 000000000..696048ab4 --- /dev/null +++ b/Libraries/test/TestServerlessApp.ALB.IntegrationTests/ALBIntegrationTestContextFixtureCollection.cs @@ -0,0 +1,12 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Xunit; + +namespace TestServerlessApp.ALB.IntegrationTests +{ + [CollectionDefinition("ALB Integration Tests")] + public class ALBIntegrationTestContextFixtureCollection : ICollectionFixture + { + } +} diff --git a/Libraries/test/TestServerlessApp.ALB.IntegrationTests/ALBTargetTests.cs b/Libraries/test/TestServerlessApp.ALB.IntegrationTests/ALBTargetTests.cs new file mode 100644 index 000000000..8e90c1a79 --- /dev/null +++ b/Libraries/test/TestServerlessApp.ALB.IntegrationTests/ALBTargetTests.cs @@ -0,0 +1,81 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System.Linq; +using System.Net; +using System.Threading.Tasks; +using Amazon.ElasticLoadBalancingV2; +using Amazon.ElasticLoadBalancingV2.Model; +using Xunit; + +namespace TestServerlessApp.ALB.IntegrationTests +{ + [Collection("ALB Integration Tests")] + public class ALBTargetTests + { + private readonly ALBIntegrationTestContextFixture _fixture; + + public ALBTargetTests(ALBIntegrationTestContextFixture fixture) + { + _fixture = fixture; + } + + [Fact] + public async Task InvokeHelloEndpoint_ReturnsSuccessWithBody() + { + // ACT + var response = await _fixture.HttpClient.GetAsync($"http://{_fixture.ALBDnsName}/hello"); + + // ASSERT + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + var body = await response.Content.ReadAsStringAsync(); + Assert.Contains("Hello from ALB Lambda!", body); + Assert.Contains("/hello", body); + } + + [Fact] + public async Task InvokeHealthEndpoint_ReturnsHealthy() + { + // ACT + var response = await _fixture.HttpClient.GetAsync($"http://{_fixture.ALBDnsName}/health"); + + // ASSERT + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + var body = await response.Content.ReadAsStringAsync(); + Assert.Contains("healthy", body); + } + + [Fact] + public async Task InvokeUnknownPath_Returns404FromDefaultAction() + { + // ACT - The ALB default action returns 404 for unmatched paths + var response = await _fixture.HttpClient.GetAsync($"http://{_fixture.ALBDnsName}/unknown-path"); + + // ASSERT + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + var body = await response.Content.ReadAsStringAsync(); + Assert.Contains("Not Found", body); + } + + [Fact] + public async Task VerifyTargetGroupsExist() + { + // ACT - Describe only the target groups associated with this stack's load balancer + Assert.False(string.IsNullOrEmpty(_fixture.LoadBalancerArn), + "LoadBalancerArn should have been resolved during test initialization"); + + var describeResponse = await _fixture.ELBv2Client.DescribeTargetGroupsAsync(new DescribeTargetGroupsRequest + { + LoadBalancerArn = _fixture.LoadBalancerArn + }); + + var albTargetGroups = describeResponse.TargetGroups + .Where(tg => tg.TargetType == TargetTypeEnum.Lambda) + .ToList(); + + // ASSERT - At least our Lambda target groups should exist for this ALB + Assert.True(albTargetGroups.Count >= 2, + $"Expected at least 2 Lambda target groups for ALB '{_fixture.ALBDnsName}', found {albTargetGroups.Count}"); + } + } +} diff --git a/Libraries/test/TestServerlessApp.ALB.IntegrationTests/DeploymentScript.ps1 b/Libraries/test/TestServerlessApp.ALB.IntegrationTests/DeploymentScript.ps1 new file mode 100644 index 000000000..f74ee365f --- /dev/null +++ b/Libraries/test/TestServerlessApp.ALB.IntegrationTests/DeploymentScript.ps1 @@ -0,0 +1,85 @@ +$ErrorActionPreference = 'Stop' + +function Get-Architecture { + $arch = [System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture + if ($arch -eq "Arm64" || $arch -eq "Arm") { + return "arm64" + } + + if ($arch -eq "X64" || $arch -eq "X86") { + return "x86_64" + } + + throw "Unsupported architecture: $arch" +} + +try +{ + Push-Location $PSScriptRoot + $guid = New-Guid + $suffix = $guid.ToString().Split('-') | Select-Object -First 1 + $identifier = "test-alb-app-" + $suffix + cd ..\TestServerlessApp.ALB + + $arch = Get-Architecture + + # Replace bucket name in aws-lambda-tools-defaults.json + $line = Get-Content .\aws-lambda-tools-defaults.json | Select-String s3-bucket | Select-Object -ExpandProperty Line + $content = Get-Content .\aws-lambda-tools-defaults.json + $content | ForEach-Object {$_ -replace $line, "`"s3-bucket`" : `"$identifier`","} | Set-Content .\aws-lambda-tools-defaults.json + + # Replace stack name in aws-lambda-tools-defaults.json + $line = Get-Content .\aws-lambda-tools-defaults.json | Select-String stack-name | Select-Object -ExpandProperty Line + $content = Get-Content .\aws-lambda-tools-defaults.json + $content | ForEach-Object {$_ -replace $line, "`"stack-name`" : `"$identifier`","} | Set-Content .\aws-lambda-tools-defaults.json + + # Replace function-architecture in aws-lambda-tools-defaults.json + $line = Get-Content .\aws-lambda-tools-defaults.json | Select-String function-architecture | Select-Object -ExpandProperty Line + $content = Get-Content .\aws-lambda-tools-defaults.json + $content | ForEach-Object {$_ -replace $line, "`"function-architecture`" : `"$arch`""} | Set-Content .\aws-lambda-tools-defaults.json + + # Extract region + $json = Get-Content .\aws-lambda-tools-defaults.json | Out-String | ConvertFrom-Json + $region = $json.region + + dotnet tool install -g Amazon.Lambda.Tools + Write-Host "Creating S3 Bucket $identifier" + + if(![string]::IsNullOrEmpty($region)) + { + aws s3 mb s3://$identifier --region $region + } + else + { + aws s3 mb s3://$identifier + } + + if (!$?) + { + throw "Failed to create the following bucket: $identifier" + } + + dotnet restore + Write-Host "Creating CloudFormation Stack $identifier, Architecture $arch" + dotnet lambda deploy-serverless + if (!$?) + { + Write-Host "Deployment failed. Fetching CloudFormation stack events for debugging..." + try { + $events = aws cloudformation describe-stack-events --stack-name $identifier --query "StackEvents[?ResourceStatus=='CREATE_FAILED' || ResourceStatus=='UPDATE_FAILED']" --output json 2>&1 + if ($events) { + Write-Host "CloudFormation failed events:" + Write-Host $events + } + } + catch { + Write-Host "Could not fetch CloudFormation events: $_" + } + + throw "Failed to create the following CloudFormation stack: $identifier" + } +} +finally +{ + Pop-Location +} diff --git a/Libraries/test/TestServerlessApp.ALB.IntegrationTests/TestServerlessApp.ALB.IntegrationTests.csproj b/Libraries/test/TestServerlessApp.ALB.IntegrationTests/TestServerlessApp.ALB.IntegrationTests.csproj new file mode 100644 index 000000000..ceb0564e5 --- /dev/null +++ b/Libraries/test/TestServerlessApp.ALB.IntegrationTests/TestServerlessApp.ALB.IntegrationTests.csproj @@ -0,0 +1,20 @@ + + + net6.0 + false + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + diff --git a/Libraries/test/TestServerlessApp.ALB/ALBFunctions.cs b/Libraries/test/TestServerlessApp.ALB/ALBFunctions.cs new file mode 100644 index 000000000..325a130e1 --- /dev/null +++ b/Libraries/test/TestServerlessApp.ALB/ALBFunctions.cs @@ -0,0 +1,109 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Annotations; +using Amazon.Lambda.Annotations.ALB; +using Amazon.Lambda.ApplicationLoadBalancerEvents; +using Amazon.Lambda.Core; +using System.Collections.Generic; + +namespace TestServerlessApp.ALB +{ + public class ALBFunctions + { + /// + /// Hello endpoint - returns a greeting message with the request path. + /// Uses the raw ApplicationLoadBalancerRequest (pass-through mode). + /// + [LambdaFunction(ResourceName = "ALBHello", MemorySize = 256, Timeout = 15)] + [ALBApi("@ALBTestListener", "/hello", 1)] + public ApplicationLoadBalancerResponse Hello(ApplicationLoadBalancerRequest request, ILambdaContext context) + { + context.Logger.LogInformation($"Hello endpoint hit. Path: {request.Path}"); + + return new ApplicationLoadBalancerResponse + { + StatusCode = 200, + StatusDescription = "200 OK", + IsBase64Encoded = false, + Headers = new Dictionary + { + { "Content-Type", "application/json" } + }, + Body = $"{{\"message\": \"Hello from ALB Lambda!\", \"path\": \"{request.Path}\"}}" + }; + } + + /// + /// Health check endpoint for ALB target group health checks. + /// Uses the raw ApplicationLoadBalancerRequest (pass-through mode). + /// + [LambdaFunction(ResourceName = "ALBHealth", MemorySize = 128, Timeout = 5)] + [ALBApi("@ALBTestListener", "/health", 2)] + public ApplicationLoadBalancerResponse Health(ApplicationLoadBalancerRequest request, ILambdaContext context) + { + return new ApplicationLoadBalancerResponse + { + StatusCode = 200, + StatusDescription = "200 OK", + IsBase64Encoded = false, + Headers = new Dictionary + { + { "Content-Type", "application/json" } + }, + Body = "{\"status\": \"healthy\"}" + }; + } + + /// + /// Greeting endpoint that uses FromQuery and FromHeader parameter binding. + /// Demonstrates ALB functions with any number of parameters using FromX attributes. + /// + [LambdaFunction(ResourceName = "ALBGreeting", MemorySize = 256, Timeout = 15)] + [ALBApi("@ALBTestListener", "/greeting", 3)] + public ApplicationLoadBalancerResponse Greeting( + [FromQuery] string name, + [FromHeader(Name = "X-Custom-Header")] string customHeader, + ILambdaContext context) + { + context.Logger.LogInformation($"Greeting endpoint hit. Name: {name}, Header: {customHeader}"); + + return new ApplicationLoadBalancerResponse + { + StatusCode = 200, + StatusDescription = "200 OK", + IsBase64Encoded = false, + Headers = new Dictionary + { + { "Content-Type", "application/json" } + }, + Body = $"{{\"message\": \"Hello {name}!\", \"customHeader\": \"{customHeader}\"}}" + }; + } + + /// + /// Endpoint that uses FromBody to deserialize JSON request body. + /// Demonstrates ALB function with body deserialization. + /// + [LambdaFunction(ResourceName = "ALBCreateItem", MemorySize = 256, Timeout = 15)] + [ALBApi("@ALBTestListener", "/items", 4, HttpMethod = "POST")] + public ApplicationLoadBalancerResponse CreateItem( + [FromBody] string body, + ILambdaContext context) + { + context.Logger.LogInformation($"CreateItem endpoint hit. Body: {body}"); + + return new ApplicationLoadBalancerResponse + { + StatusCode = 201, + StatusDescription = "201 Created", + IsBase64Encoded = false, + Headers = new Dictionary + { + { "Content-Type", "application/json" } + }, + Body = $"{{\"created\": true, \"body\": \"{body}\"}}" + }; + } + } +} diff --git a/Libraries/test/TestServerlessApp.ALB/AssemblyAttributes.cs b/Libraries/test/TestServerlessApp.ALB/AssemblyAttributes.cs new file mode 100644 index 000000000..a0821fb25 --- /dev/null +++ b/Libraries/test/TestServerlessApp.ALB/AssemblyAttributes.cs @@ -0,0 +1,6 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Core; + +[assembly: LambdaSerializer(typeof(Amazon.Lambda.Serialization.SystemTextJson.DefaultLambdaJsonSerializer))] diff --git a/Libraries/test/TestServerlessApp.ALB/TestServerlessApp.ALB.csproj b/Libraries/test/TestServerlessApp.ALB/TestServerlessApp.ALB.csproj new file mode 100644 index 000000000..9f9e5630c --- /dev/null +++ b/Libraries/test/TestServerlessApp.ALB/TestServerlessApp.ALB.csproj @@ -0,0 +1,15 @@ + + + net6.0 + true + Lambda + true + + + + + + + + + diff --git a/Libraries/test/TestServerlessApp.ALB/aws-lambda-tools-defaults.json b/Libraries/test/TestServerlessApp.ALB/aws-lambda-tools-defaults.json new file mode 100644 index 000000000..49464b51c --- /dev/null +++ b/Libraries/test/TestServerlessApp.ALB/aws-lambda-tools-defaults.json @@ -0,0 +1,17 @@ +{ + "Information": [ + "This file provides default values for the deployment wizard inside Visual Studio and the AWS Lambda commands added to the .NET Core CLI.", + "To learn more about the Lambda commands with the .NET Core CLI execute the following command at the command line in the project root directory.", + "dotnet lambda help", + "All the command line options for the Lambda command can be specified in this file." + ], + "profile": "", + "region": "us-west-2", + "configuration": "Release", + "framework": "net6.0", +"s3-bucket" : "test-alb-app-dce31eae", +"stack-name" : "test-alb-app-dce31eae", + "template": "serverless.template", + "template-parameters": "", +"function-architecture" : "x86_64" +} diff --git a/Libraries/test/TestServerlessApp.ALB/serverless.template b/Libraries/test/TestServerlessApp.ALB/serverless.template new file mode 100644 index 000000000..e07e2e226 --- /dev/null +++ b/Libraries/test/TestServerlessApp.ALB/serverless.template @@ -0,0 +1,542 @@ +{ + "AWSTemplateFormatVersion": "2010-09-09", + "Transform": "AWS::Serverless-2016-10-31", + "Description": "ALB Integration Test Stack - VPC and ALB infrastructure for testing Lambda ALB annotations This template is partially managed by Amazon.Lambda.Annotations (v1.13.0.0).", + "Resources": { + "ALBTestVPC": { + "Type": "AWS::EC2::VPC", + "Properties": { + "CidrBlock": "10.0.0.0/16", + "EnableDnsSupport": true, + "EnableDnsHostnames": true, + "Tags": [ + { + "Key": "Name", + "Value": "ALB-Integration-Test-VPC" + } + ] + } + }, + "ALBTestInternetGateway": { + "Type": "AWS::EC2::InternetGateway" + }, + "ALBTestGatewayAttachment": { + "Type": "AWS::EC2::VPCGatewayAttachment", + "Properties": { + "VpcId": { + "Ref": "ALBTestVPC" + }, + "InternetGatewayId": { + "Ref": "ALBTestInternetGateway" + } + } + }, + "ALBTestRouteTable": { + "Type": "AWS::EC2::RouteTable", + "Properties": { + "VpcId": { + "Ref": "ALBTestVPC" + } + } + }, + "ALBTestRoute": { + "Type": "AWS::EC2::Route", + "DependsOn": "ALBTestGatewayAttachment", + "Properties": { + "RouteTableId": { + "Ref": "ALBTestRouteTable" + }, + "DestinationCidrBlock": "0.0.0.0/0", + "GatewayId": { + "Ref": "ALBTestInternetGateway" + } + } + }, + "ALBTestSubnet1": { + "Type": "AWS::EC2::Subnet", + "Properties": { + "VpcId": { + "Ref": "ALBTestVPC" + }, + "CidrBlock": "10.0.1.0/24", + "AvailabilityZone": { + "Fn::Select": [ + 0, + { + "Fn::GetAZs": "" + } + ] + }, + "MapPublicIpOnLaunch": true + } + }, + "ALBTestSubnet2": { + "Type": "AWS::EC2::Subnet", + "Properties": { + "VpcId": { + "Ref": "ALBTestVPC" + }, + "CidrBlock": "10.0.2.0/24", + "AvailabilityZone": { + "Fn::Select": [ + 1, + { + "Fn::GetAZs": "" + } + ] + }, + "MapPublicIpOnLaunch": true + } + }, + "ALBTestSubnet1RouteTableAssociation": { + "Type": "AWS::EC2::SubnetRouteTableAssociation", + "Properties": { + "SubnetId": { + "Ref": "ALBTestSubnet1" + }, + "RouteTableId": { + "Ref": "ALBTestRouteTable" + } + } + }, + "ALBTestSubnet2RouteTableAssociation": { + "Type": "AWS::EC2::SubnetRouteTableAssociation", + "Properties": { + "SubnetId": { + "Ref": "ALBTestSubnet2" + }, + "RouteTableId": { + "Ref": "ALBTestRouteTable" + } + } + }, + "ALBTestSecurityGroup": { + "Type": "AWS::EC2::SecurityGroup", + "Properties": { + "GroupDescription": "ALB Integration Test Security Group", + "VpcId": { + "Ref": "ALBTestVPC" + }, + "SecurityGroupIngress": [ + { + "IpProtocol": "tcp", + "FromPort": 80, + "ToPort": 80, + "CidrIp": "0.0.0.0/0" + } + ] + } + }, + "ALBTestLoadBalancer": { + "Type": "AWS::ElasticLoadBalancingV2::LoadBalancer", + "Properties": { + "Type": "application", + "Scheme": "internet-facing", + "Subnets": [ + { + "Ref": "ALBTestSubnet1" + }, + { + "Ref": "ALBTestSubnet2" + } + ], + "SecurityGroups": [ + { + "Ref": "ALBTestSecurityGroup" + } + ] + } + }, + "ALBTestListener": { + "Type": "AWS::ElasticLoadBalancingV2::Listener", + "Properties": { + "LoadBalancerArn": { + "Ref": "ALBTestLoadBalancer" + }, + "Port": 80, + "Protocol": "HTTP", + "DefaultActions": [ + { + "Type": "fixed-response", + "FixedResponseConfig": { + "StatusCode": "404", + "ContentType": "application/json", + "MessageBody": "{\"error\": \"Not Found\"}" + } + } + ] + } + }, + "ALBHello": { + "Type": "AWS::Serverless::Function", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations", + "SyncedAlbResources": [ + "ALBHelloALBPermission", + "ALBHelloALBTargetGroup", + "ALBHelloALBListenerRule" + ] + }, + "Properties": { + "Runtime": "dotnet6", + "CodeUri": ".", + "MemorySize": 256, + "Timeout": 15, + "Policies": [ + "AWSLambdaBasicExecutionRole" + ], + "PackageType": "Zip", + "Handler": "TestServerlessApp.ALB::TestServerlessApp.ALB.ALBFunctions_Hello_Generated::Hello" + } + }, + "ALBHelloALBPermission": { + "Type": "AWS::Lambda::Permission", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations" + }, + "Properties": { + "FunctionName": { + "Fn::GetAtt": [ + "ALBHello", + "Arn" + ] + }, + "Action": "lambda:InvokeFunction", + "Principal": "elasticloadbalancing.amazonaws.com" + } + }, + "ALBHelloALBTargetGroup": { + "Type": "AWS::ElasticLoadBalancingV2::TargetGroup", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations" + }, + "DependsOn": "ALBHelloALBPermission", + "Properties": { + "TargetType": "lambda", + "Targets": [ + { + "Id": { + "Fn::GetAtt": [ + "ALBHello", + "Arn" + ] + } + } + ] + } + }, + "ALBHelloALBListenerRule": { + "Type": "AWS::ElasticLoadBalancingV2::ListenerRule", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations" + }, + "Properties": { + "Priority": 1, + "Conditions": [ + { + "Field": "path-pattern", + "PathPatternConfig": { + "Values": [ + "/hello" + ] + } + } + ], + "Actions": [ + { + "Type": "forward", + "TargetGroupArn": { + "Ref": "ALBHelloALBTargetGroup" + } + } + ], + "ListenerArn": { + "Ref": "ALBTestListener" + } + } + }, + "ALBHealth": { + "Type": "AWS::Serverless::Function", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations", + "SyncedAlbResources": [ + "ALBHealthALBPermission", + "ALBHealthALBTargetGroup", + "ALBHealthALBListenerRule" + ] + }, + "Properties": { + "Runtime": "dotnet6", + "CodeUri": ".", + "MemorySize": 128, + "Timeout": 5, + "Policies": [ + "AWSLambdaBasicExecutionRole" + ], + "PackageType": "Zip", + "Handler": "TestServerlessApp.ALB::TestServerlessApp.ALB.ALBFunctions_Health_Generated::Health" + } + }, + "ALBHealthALBPermission": { + "Type": "AWS::Lambda::Permission", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations" + }, + "Properties": { + "FunctionName": { + "Fn::GetAtt": [ + "ALBHealth", + "Arn" + ] + }, + "Action": "lambda:InvokeFunction", + "Principal": "elasticloadbalancing.amazonaws.com" + } + }, + "ALBHealthALBTargetGroup": { + "Type": "AWS::ElasticLoadBalancingV2::TargetGroup", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations" + }, + "DependsOn": "ALBHealthALBPermission", + "Properties": { + "TargetType": "lambda", + "Targets": [ + { + "Id": { + "Fn::GetAtt": [ + "ALBHealth", + "Arn" + ] + } + } + ] + } + }, + "ALBHealthALBListenerRule": { + "Type": "AWS::ElasticLoadBalancingV2::ListenerRule", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations" + }, + "Properties": { + "Priority": 2, + "Conditions": [ + { + "Field": "path-pattern", + "PathPatternConfig": { + "Values": [ + "/health" + ] + } + } + ], + "Actions": [ + { + "Type": "forward", + "TargetGroupArn": { + "Ref": "ALBHealthALBTargetGroup" + } + } + ], + "ListenerArn": { + "Ref": "ALBTestListener" + } + } + }, + "ALBGreeting": { + "Type": "AWS::Serverless::Function", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations", + "SyncedAlbResources": [ + "ALBGreetingALBPermission", + "ALBGreetingALBTargetGroup", + "ALBGreetingALBListenerRule" + ] + }, + "Properties": { + "Runtime": "dotnet6", + "CodeUri": ".", + "MemorySize": 256, + "Timeout": 15, + "Policies": [ + "AWSLambdaBasicExecutionRole" + ], + "PackageType": "Zip", + "Handler": "TestServerlessApp.ALB::TestServerlessApp.ALB.ALBFunctions_Greeting_Generated::Greeting" + } + }, + "ALBGreetingALBPermission": { + "Type": "AWS::Lambda::Permission", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations" + }, + "Properties": { + "FunctionName": { + "Fn::GetAtt": [ + "ALBGreeting", + "Arn" + ] + }, + "Action": "lambda:InvokeFunction", + "Principal": "elasticloadbalancing.amazonaws.com" + } + }, + "ALBGreetingALBTargetGroup": { + "Type": "AWS::ElasticLoadBalancingV2::TargetGroup", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations" + }, + "DependsOn": "ALBGreetingALBPermission", + "Properties": { + "TargetType": "lambda", + "Targets": [ + { + "Id": { + "Fn::GetAtt": [ + "ALBGreeting", + "Arn" + ] + } + } + ] + } + }, + "ALBGreetingALBListenerRule": { + "Type": "AWS::ElasticLoadBalancingV2::ListenerRule", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations" + }, + "Properties": { + "Priority": 3, + "Conditions": [ + { + "Field": "path-pattern", + "PathPatternConfig": { + "Values": [ + "/greeting" + ] + } + } + ], + "Actions": [ + { + "Type": "forward", + "TargetGroupArn": { + "Ref": "ALBGreetingALBTargetGroup" + } + } + ], + "ListenerArn": { + "Ref": "ALBTestListener" + } + } + }, + "ALBCreateItem": { + "Type": "AWS::Serverless::Function", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations", + "SyncedAlbResources": [ + "ALBCreateItemALBPermission", + "ALBCreateItemALBTargetGroup", + "ALBCreateItemALBListenerRule" + ] + }, + "Properties": { + "Runtime": "dotnet6", + "CodeUri": ".", + "MemorySize": 256, + "Timeout": 15, + "Policies": [ + "AWSLambdaBasicExecutionRole" + ], + "PackageType": "Zip", + "Handler": "TestServerlessApp.ALB::TestServerlessApp.ALB.ALBFunctions_CreateItem_Generated::CreateItem" + } + }, + "ALBCreateItemALBPermission": { + "Type": "AWS::Lambda::Permission", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations" + }, + "Properties": { + "FunctionName": { + "Fn::GetAtt": [ + "ALBCreateItem", + "Arn" + ] + }, + "Action": "lambda:InvokeFunction", + "Principal": "elasticloadbalancing.amazonaws.com" + } + }, + "ALBCreateItemALBTargetGroup": { + "Type": "AWS::ElasticLoadBalancingV2::TargetGroup", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations" + }, + "DependsOn": "ALBCreateItemALBPermission", + "Properties": { + "TargetType": "lambda", + "Targets": [ + { + "Id": { + "Fn::GetAtt": [ + "ALBCreateItem", + "Arn" + ] + } + } + ] + } + }, + "ALBCreateItemALBListenerRule": { + "Type": "AWS::ElasticLoadBalancingV2::ListenerRule", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations" + }, + "Properties": { + "Priority": 4, + "Conditions": [ + { + "Field": "path-pattern", + "PathPatternConfig": { + "Values": [ + "/items" + ] + } + }, + { + "Field": "http-request-method", + "HttpRequestMethodConfig": { + "Values": [ + "POST" + ] + } + } + ], + "Actions": [ + { + "Type": "forward", + "TargetGroupArn": { + "Ref": "ALBCreateItemALBTargetGroup" + } + } + ], + "ListenerArn": { + "Ref": "ALBTestListener" + } + } + } + }, + "Outputs": { + "ALBDnsName": { + "Value": { + "Fn::GetAtt": [ + "ALBTestLoadBalancer", + "DNSName" + ] + }, + "Description": "ALB DNS Name for integration testing" + } + } +} \ No newline at end of file diff --git a/Libraries/test/TestServerlessApp.IntegrationTests/DeploymentScript.ps1 b/Libraries/test/TestServerlessApp.IntegrationTests/DeploymentScript.ps1 index 7a5cd5644..5bc3b87fc 100644 --- a/Libraries/test/TestServerlessApp.IntegrationTests/DeploymentScript.ps1 +++ b/Libraries/test/TestServerlessApp.IntegrationTests/DeploymentScript.ps1 @@ -68,6 +68,15 @@ try Write-Host "Added TestQueue resource to serverless.template" } + # Add TestS3Bucket resource to serverless.template for S3 event integration testing + # The source generator creates a Ref to TestS3Bucket but doesn't define the resource itself + $template = Get-Content $templatePath | Out-String | ConvertFrom-Json + if (-not $template.Resources.PSObject.Properties['TestS3Bucket']) { + $template.Resources | Add-Member -NotePropertyName "TestS3Bucket" -NotePropertyValue @{ Type = "AWS::S3::Bucket" } -Force + $template | ConvertTo-Json -Depth 100 | Set-Content $templatePath + Write-Host "Added TestS3Bucket resource to serverless.template" + } + dotnet restore Write-Host "Creating CloudFormation Stack $identifier, Architecture $arch, Runtime $runtime" dotnet lambda deploy-serverless --template-parameters "ArchitectureTypeParameter=$arch" diff --git a/Libraries/test/TestServerlessApp.IntegrationTests/FunctionUrlExample.cs b/Libraries/test/TestServerlessApp.IntegrationTests/FunctionUrlExample.cs new file mode 100644 index 000000000..b3f97929b --- /dev/null +++ b/Libraries/test/TestServerlessApp.IntegrationTests/FunctionUrlExample.cs @@ -0,0 +1,103 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System; +using System.Linq; +using System.Net.Http; +using System.Threading.Tasks; +using Newtonsoft.Json.Linq; +using Xunit; + +namespace TestServerlessApp.IntegrationTests +{ + [Collection("Integration Tests")] + public class FunctionUrlExample + { + private readonly IntegrationTestContextFixture _fixture; + + public FunctionUrlExample(IntegrationTestContextFixture fixture) + { + _fixture = fixture; + } + + [Fact] + public async Task GetItems_WithCategory_ReturnsOkWithItems() + { + Assert.False(string.IsNullOrEmpty(_fixture.FunctionUrlPrefix), "FunctionUrlPrefix should not be empty. The Function URL was not discovered during setup."); + + var response = await GetWithRetryAsync($"{_fixture.FunctionUrlPrefix}?category=electronics"); + response.EnsureSuccessStatusCode(); + + var content = await response.Content.ReadAsStringAsync(); + var json = JObject.Parse(content); + + Assert.Equal("electronics", json["category"]?.ToString()); + Assert.NotNull(json["items"]); + var items = json["items"].ToObject(); + Assert.Equal(2, items.Length); + Assert.Contains("item1", items); + Assert.Contains("item2", items); + } + + [Fact] + public async Task GetItems_LogsToCloudWatch() + { + Assert.False(string.IsNullOrEmpty(_fixture.FunctionUrlPrefix), "FunctionUrlPrefix should not be empty. The Function URL was not discovered during setup."); + + var response = await GetWithRetryAsync($"{_fixture.FunctionUrlPrefix}?category=books"); + response.EnsureSuccessStatusCode(); + + var lambdaFunctionName = _fixture.LambdaFunctions + .FirstOrDefault(x => string.Equals(x.LogicalId, "TestServerlessAppFunctionUrlExampleGetItemsGenerated"))?.Name; + Assert.False(string.IsNullOrEmpty(lambdaFunctionName)); + + var logGroupName = _fixture.CloudWatchHelper.GetLogGroupName(lambdaFunctionName); + Assert.True( + await _fixture.CloudWatchHelper.MessageExistsInRecentLogEventsAsync("Getting items for category: books", logGroupName, logGroupName), + "Expected log message not found in CloudWatch logs"); + } + + [Fact] + public async Task VerifyFunctionUrlConfig_HasNoneAuthType() + { + var lambdaFunctionName = _fixture.LambdaFunctions + .FirstOrDefault(x => string.Equals(x.LogicalId, "TestServerlessAppFunctionUrlExampleGetItemsGenerated"))?.Name; + Assert.False(string.IsNullOrEmpty(lambdaFunctionName)); + + var functionUrlConfig = await _fixture.LambdaHelper.GetFunctionUrlConfigAsync(lambdaFunctionName); + Assert.NotNull(functionUrlConfig); + Assert.Equal("NONE", functionUrlConfig.AuthType.Value); + Assert.False(string.IsNullOrEmpty(functionUrlConfig.FunctionUrl), "Function URL should not be empty"); + Assert.Contains(".lambda-url.", functionUrlConfig.FunctionUrl); + } + + private async Task GetWithRetryAsync(string url) + { + const int maxAttempts = 10; + HttpResponseMessage response = null; + + for (var attempt = 0; attempt < maxAttempts; attempt++) + { + await Task.Delay(attempt * 1000); + try + { + response = await _fixture.HttpClient.GetAsync(url); + + // If we get a 403 Forbidden, it may be an eventual consistency issue + // with the Function URL permissions propagating. + if (response.StatusCode == System.Net.HttpStatusCode.Forbidden) + continue; + + break; + } + catch + { + if (attempt + 1 == maxAttempts) + throw; + } + } + + return response; + } + } +} diff --git a/Libraries/test/TestServerlessApp.IntegrationTests/IntegrationTestContextFixture.cs b/Libraries/test/TestServerlessApp.IntegrationTests/IntegrationTestContextFixture.cs index fb689cd7c..c4b139417 100644 --- a/Libraries/test/TestServerlessApp.IntegrationTests/IntegrationTestContextFixture.cs +++ b/Libraries/test/TestServerlessApp.IntegrationTests/IntegrationTestContextFixture.cs @@ -1,3 +1,6 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + using System; using System.Collections.Generic; using System.IO; @@ -25,17 +28,21 @@ public class IntegrationTestContextFixture : IAsyncLifetime public readonly LambdaHelper LambdaHelper; public readonly CloudWatchHelper CloudWatchHelper; + public readonly S3Helper S3HelperInstance; public readonly HttpClient HttpClient; public string RestApiUrlPrefix; public string HttpApiUrlPrefix; + public string FunctionUrlPrefix; public string TestQueueARN; + public string TestS3BucketName; public List LambdaFunctions; public IntegrationTestContextFixture() { _cloudFormationHelper = new CloudFormationHelper(new AmazonCloudFormationClient(Amazon.RegionEndpoint.USWest2)); _s3Helper = new S3Helper(new AmazonS3Client(Amazon.RegionEndpoint.USWest2)); + S3HelperInstance = _s3Helper; LambdaHelper = new LambdaHelper(new AmazonLambdaClient(Amazon.RegionEndpoint.USWest2)); CloudWatchHelper = new CloudWatchHelper(new AmazonCloudWatchLogsClient(Amazon.RegionEndpoint.USWest2)); HttpClient = new HttpClient(); @@ -77,16 +84,32 @@ public async Task InitializeAsync() Console.WriteLine($"[IntegrationTest] TestQueue URL: {queueUrl}"); Assert.False(string.IsNullOrEmpty(queueUrl), $"CloudFormation resource 'TestQueue' was not found in stack '{_stackName}'."); TestQueueARN = ConvertSqsUrlToArn(queueUrl); + + // Get the S3 bucket name from the physical resource ID + TestS3BucketName = await _cloudFormationHelper.GetResourcePhysicalIdAsync(_stackName, "TestS3Bucket"); + Console.WriteLine($"[IntegrationTest] TestS3Bucket: {TestS3BucketName}"); + Assert.False(string.IsNullOrEmpty(TestS3BucketName), $"CloudFormation resource 'TestS3Bucket' was not found in stack '{_stackName}'."); + LambdaFunctions = await LambdaHelper.FilterByCloudFormationStackAsync(_stackName); Console.WriteLine($"[IntegrationTest] Found {LambdaFunctions.Count} Lambda functions: {string.Join(", ", LambdaFunctions.Select(f => f.Name ?? "(null)"))}"); Assert.True(await _s3Helper.BucketExistsAsync(_bucketName), $"S3 bucket {_bucketName} should exist"); - Assert.Equal(36, LambdaFunctions.Count); + Assert.Equal(38, LambdaFunctions.Count); Assert.False(string.IsNullOrEmpty(RestApiUrlPrefix), "RestApiUrlPrefix should not be empty"); Assert.False(string.IsNullOrEmpty(HttpApiUrlPrefix), "HttpApiUrlPrefix should not be empty"); await LambdaHelper.WaitTillNotPending(LambdaFunctions.Where(x => x.Name != null).Select(x => x.Name).ToList()); + // Discover the Function URL for the FunctionUrlExample function + var functionUrlLambdaName = LambdaFunctions + .FirstOrDefault(x => string.Equals(x.LogicalId, "TestServerlessAppFunctionUrlExampleGetItemsGenerated"))?.Name; + if (!string.IsNullOrEmpty(functionUrlLambdaName)) + { + var functionUrlConfig = await LambdaHelper.GetFunctionUrlConfigAsync(functionUrlLambdaName); + FunctionUrlPrefix = functionUrlConfig.FunctionUrl.TrimEnd('/'); + Console.WriteLine($"[IntegrationTest] FunctionUrlPrefix: {FunctionUrlPrefix}"); + } + // Wait an additional 10 seconds for any other eventually consistency state to finish up. await Task.Delay(10000); } diff --git a/Libraries/test/TestServerlessApp.IntegrationTests/S3EventNotification.cs b/Libraries/test/TestServerlessApp.IntegrationTests/S3EventNotification.cs new file mode 100644 index 000000000..d9758ae00 --- /dev/null +++ b/Libraries/test/TestServerlessApp.IntegrationTests/S3EventNotification.cs @@ -0,0 +1,53 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System.Linq; +using System.Threading.Tasks; +using Amazon.S3; +using Xunit; + +namespace TestServerlessApp.IntegrationTests +{ + [Collection("Integration Tests")] + public class S3EventNotification + { + private readonly IntegrationTestContextFixture _fixture; + + public S3EventNotification(IntegrationTestContextFixture fixture) + { + _fixture = fixture; + } + + [Fact] + public async Task VerifyS3EventNotificationConfiguration() + { + // Verify the Lambda function exists in the stack + var lambdaFunction = _fixture.LambdaFunctions + .FirstOrDefault(x => string.Equals(x.LogicalId, "S3EventHandler")); + Assert.NotNull(lambdaFunction); + Assert.NotNull(lambdaFunction.Name); + + // Verify S3 bucket notification is configured correctly + var notificationConfig = await _fixture.S3HelperInstance + .GetBucketNotificationAsync(_fixture.TestS3BucketName); + + var lambdaConfigs = notificationConfig.LambdaFunctionConfigurations; + Assert.Single(lambdaConfigs); + + var config = lambdaConfigs.First(); + + // Verify the notification points to the correct Lambda function ARN + Assert.Contains(lambdaFunction.Name, config.FunctionArn); + + // Verify the event type is s3:ObjectCreated:* + Assert.Single(config.Events); + Assert.Equal(EventType.ObjectCreatedAll, config.Events.First()); + + // Verify the suffix filter is .json + var filterRules = config.Filter.S3KeyFilter.FilterRules; + Assert.Single(filterRules); + var suffixRule = filterRules.First(r => string.Equals(r.Name, "suffix", System.StringComparison.OrdinalIgnoreCase)); + Assert.Equal(".json", suffixRule.Value); + } + } +} diff --git a/Libraries/test/TestServerlessApp.NET8/serverless.template b/Libraries/test/TestServerlessApp.NET8/serverless.template index c139ace2b..03b6cb0d5 100644 --- a/Libraries/test/TestServerlessApp.NET8/serverless.template +++ b/Libraries/test/TestServerlessApp.NET8/serverless.template @@ -1,7 +1,7 @@ { "AWSTemplateFormatVersion": "2010-09-09", "Transform": "AWS::Serverless-2016-10-31", - "Description": "This template is partially managed by Amazon.Lambda.Annotations (v1.10.0.0).", + "Description": "This template is partially managed by Amazon.Lambda.Annotations (v1.13.0.0).", "Resources": { "TestServerlessAppNET8FunctionsToUpperGenerated": { "Type": "AWS::Serverless::Function", diff --git a/Libraries/test/TestServerlessApp/ALBEventExamples/ValidALBEvents.cs.txt b/Libraries/test/TestServerlessApp/ALBEventExamples/ValidALBEvents.cs.txt new file mode 100644 index 000000000..1becbac0e --- /dev/null +++ b/Libraries/test/TestServerlessApp/ALBEventExamples/ValidALBEvents.cs.txt @@ -0,0 +1,42 @@ +using Amazon.Lambda.Annotations; +using Amazon.Lambda.Annotations.ALB; +using Amazon.Lambda.ApplicationLoadBalancerEvents; +using Amazon.Lambda.Core; +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace TestServerlessApp.ALBEventExamples +{ + // This file represents valid usage of the ALBApiAttribute. This is added as .txt file since we do not want to deploy these functions during our integration tests. + // This file is only sent as input to the source generator unit tests. + // Refer to VerifyValidALBEvents unit test. + + public class ValidALBEvents + { + [LambdaFunction(ResourceName = "ALBHelloWorld")] + [ALBApi("arn:aws:elasticloadbalancing:us-east-1:123456789012:listener/app/my-alb/abc/def", "/hello", 1)] + public ApplicationLoadBalancerResponse Hello(ApplicationLoadBalancerRequest request, ILambdaContext context) + { + return new ApplicationLoadBalancerResponse + { + StatusCode = 200, + StatusDescription = "200 OK", + Headers = new Dictionary { { "Content-Type", "text/plain" } }, + Body = "Hello from ALB Lambda!" + }; + } + + [LambdaFunction(ResourceName = "ALBWithOptions")] + [ALBApi("@MyALBListener", "/api/*", 5, MultiValueHeaders = true, HostHeader = "api.example.com", HttpMethod = "POST")] + public async Task HandleRequest(ApplicationLoadBalancerRequest request, ILambdaContext context) + { + await Task.CompletedTask; + return new ApplicationLoadBalancerResponse + { + StatusCode = 200, + StatusDescription = "200 OK", + Body = "OK" + }; + } + } +} diff --git a/Libraries/test/TestServerlessApp/FunctionUrlExample.cs b/Libraries/test/TestServerlessApp/FunctionUrlExample.cs new file mode 100644 index 000000000..4909c768e --- /dev/null +++ b/Libraries/test/TestServerlessApp/FunctionUrlExample.cs @@ -0,0 +1,20 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Annotations; +using Amazon.Lambda.Annotations.APIGateway; +using Amazon.Lambda.Core; + +namespace TestServerlessApp +{ + public class FunctionUrlExample + { + [LambdaFunction(PackageType = LambdaPackageType.Image)] + [FunctionUrl(AuthType = FunctionUrlAuthType.NONE)] + public IHttpResult GetItems([FromQuery] string category, ILambdaContext context) + { + context.Logger.LogLine($"Getting items for category: {category}"); + return HttpResults.Ok(new { items = new[] { "item1", "item2" }, category }); + } + } +} diff --git a/Libraries/test/TestServerlessApp/S3EventExamples/S3EventProcessing.cs b/Libraries/test/TestServerlessApp/S3EventExamples/S3EventProcessing.cs new file mode 100644 index 000000000..0ee914c76 --- /dev/null +++ b/Libraries/test/TestServerlessApp/S3EventExamples/S3EventProcessing.cs @@ -0,0 +1,21 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Annotations; +using Amazon.Lambda.Annotations.S3; +using Amazon.Lambda.Core; +using Amazon.Lambda.S3Events; +using System; + +namespace TestServerlessApp.S3EventExamples +{ + public class S3EventProcessing + { + [LambdaFunction(ResourceName = "S3EventHandler", Policies = "AWSLambdaBasicExecutionRole,AmazonS3ReadOnlyAccess", PackageType = LambdaPackageType.Image)] + [S3Event("@TestS3Bucket", Events = "s3:ObjectCreated:*", FilterSuffix = ".json")] + public void ProcessS3Event(S3Event evnt) + { + Console.WriteLine($"Received S3 event with {evnt.Records.Count} records"); + } + } +} diff --git a/Libraries/test/TestServerlessApp/S3EventExamples/ValidS3Events.cs.txt b/Libraries/test/TestServerlessApp/S3EventExamples/ValidS3Events.cs.txt new file mode 100644 index 000000000..9a5da1f8e --- /dev/null +++ b/Libraries/test/TestServerlessApp/S3EventExamples/ValidS3Events.cs.txt @@ -0,0 +1,38 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using Amazon.Lambda.Annotations; +using Amazon.Lambda.Annotations.S3; +using Amazon.Lambda.S3Events; +using System; +using System.Threading.Tasks; + +namespace TestServerlessApp.S3EventExamples +{ + // This file represents valid usage of the S3EventAttribute. This is added as .txt file since we do not want to deploy these functions during our integration tests. + // This file is only sent as input to the source generator unit tests. + + public class ValidS3Events + { + [LambdaFunction(PackageType = LambdaPackageType.Image)] + [S3Event("@MyBucket")] + public void ProcessS3Event(S3Event evnt) + { + Console.WriteLine($"Event processed: {evnt}"); + } + + [LambdaFunction(PackageType = LambdaPackageType.Image)] + [S3Event("@MyBucket", Events = "s3:ObjectCreated:*;s3:ObjectRemoved:*", FilterPrefix = "uploads/", FilterSuffix = ".jpg")] + public async Task ProcessS3EventWithFilters(S3Event evnt) + { + await Console.Out.WriteLineAsync($"Event processed: {evnt}"); + } + + [LambdaFunction(PackageType = LambdaPackageType.Image)] + [S3Event("@ImageBucket", ResourceName = "ImageBucketEvent", Enabled = false)] + public void ProcessS3EventDisabled(S3Event evnt) + { + Console.WriteLine($"Event processed: {evnt}"); + } + } +} diff --git a/Libraries/test/TestServerlessApp/TestServerlessApp.csproj b/Libraries/test/TestServerlessApp/TestServerlessApp.csproj index 921e3d372..83d7cf89e 100644 --- a/Libraries/test/TestServerlessApp/TestServerlessApp.csproj +++ b/Libraries/test/TestServerlessApp/TestServerlessApp.csproj @@ -27,6 +27,7 @@ + diff --git a/Libraries/test/TestServerlessApp/aws-lambda-tools-defaults.json b/Libraries/test/TestServerlessApp/aws-lambda-tools-defaults.json index 0b96350ff..71f6d708b 100644 --- a/Libraries/test/TestServerlessApp/aws-lambda-tools-defaults.json +++ b/Libraries/test/TestServerlessApp/aws-lambda-tools-defaults.json @@ -13,7 +13,7 @@ "template": "serverless.template", "template-parameters": "", "docker-host-build-output-dir": "./bin/Release/lambda-publish", - "s3-bucket": "test-serverless-app", - "stack-name": "test-serverless-app", - "function-architecture": "x86_64" -} \ No newline at end of file +"s3-bucket" : "test-serverless-app-535afbc5", +"stack-name" : "test-serverless-app-535afbc5", +"function-architecture" : "x86_64" +} diff --git a/Libraries/test/TestServerlessApp/serverless.template b/Libraries/test/TestServerlessApp/serverless.template index a0bf929eb..b5753ecfd 100644 --- a/Libraries/test/TestServerlessApp/serverless.template +++ b/Libraries/test/TestServerlessApp/serverless.template @@ -1,11 +1,8 @@ { "AWSTemplateFormatVersion": "2010-09-09", "Transform": "AWS::Serverless-2016-10-31", - "Description": "This template is partially managed by Amazon.Lambda.Annotations (v1.10.0.0).", + "Description": "This template is partially managed by Amazon.Lambda.Annotations (v1.13.0.0).", "Resources": { - "TestQueue": { - "Type": "AWS::SQS::Queue" - }, "AnnotationsHttpApi": { "Type": "AWS::Serverless::HttpApi", "Metadata": { @@ -801,6 +798,30 @@ } } }, + "TestServerlessAppFunctionUrlExampleGetItemsGenerated": { + "Type": "AWS::Serverless::Function", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations", + "SyncedFunctionUrlConfig": true + }, + "Properties": { + "MemorySize": 512, + "Timeout": 30, + "Policies": [ + "AWSLambdaBasicExecutionRole" + ], + "PackageType": "Image", + "ImageUri": ".", + "ImageConfig": { + "Command": [ + "TestServerlessApp::TestServerlessApp.FunctionUrlExample_GetItems_Generated::GetItems" + ] + }, + "FunctionUrlConfig": { + "AuthType": "NONE" + } + } + }, "GreeterSayHello": { "Type": "AWS::Serverless::Function", "Metadata": { @@ -991,6 +1012,60 @@ } } }, + "S3EventHandler": { + "Type": "AWS::Serverless::Function", + "Metadata": { + "Tool": "Amazon.Lambda.Annotations", + "SyncedEvents": [ + "TestS3Bucket" + ], + "SyncedEventProperties": { + "TestS3Bucket": [ + "Bucket.Ref", + "Events", + "Filter.S3Key.Rules" + ] + } + }, + "Properties": { + "MemorySize": 512, + "Timeout": 30, + "Policies": [ + "AWSLambdaBasicExecutionRole", + "AmazonS3ReadOnlyAccess" + ], + "PackageType": "Image", + "ImageUri": ".", + "ImageConfig": { + "Command": [ + "TestServerlessApp::TestServerlessApp.S3EventExamples.S3EventProcessing_ProcessS3Event_Generated::ProcessS3Event" + ] + }, + "Events": { + "TestS3Bucket": { + "Type": "S3", + "Properties": { + "Events": [ + "s3:ObjectCreated:*" + ], + "Filter": { + "S3Key": { + "Rules": [ + { + "Name": "suffix", + "Value": ".json" + } + ] + } + }, + "Bucket": { + "Ref": "TestS3Bucket" + } + } + } + } + } + }, "SimpleCalculatorAdd": { "Type": "AWS::Serverless::Function", "Metadata": { diff --git a/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.IntegrationTests/BaseApiGatewayTest.cs b/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.IntegrationTests/BaseApiGatewayTest.cs index f277ffa2a..c9998141c 100644 --- a/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.IntegrationTests/BaseApiGatewayTest.cs +++ b/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.IntegrationTests/BaseApiGatewayTest.cs @@ -48,6 +48,7 @@ protected async Task CleanupAsync() CancellationTokenSource.Dispose(); CancellationTokenSource = new CancellationTokenSource(); } + Environment.SetEnvironmentVariable("APIGATEWAY_EMULATOR_ROUTE_CONFIG", null); } protected async Task StartTestToolProcessAsync(ApiGatewayEmulatorMode apiGatewayMode, string routeName, int lambdaPort, int apiGatewayPort, CancellationTokenSource cancellationTokenSource, string httpMethod = "POST") diff --git a/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.IntegrationTests/SQSEventSourceTests.cs b/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.IntegrationTests/SQSEventSourceTests.cs index fc4aa1882..f02ee3abe 100644 --- a/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.IntegrationTests/SQSEventSourceTests.cs +++ b/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.IntegrationTests/SQSEventSourceTests.cs @@ -183,7 +183,7 @@ public async Task ProcessMessagesFromMultipleEventSources() await sqsClient.SendMessageAsync(queueUrl2, "MessageFromQueue2"); var startTime = DateTime.UtcNow; - while (listOfProcessedMessages.Count == 0 && DateTime.UtcNow < startTime.AddMinutes(2)) + while (listOfProcessedMessages.Count < 2 && DateTime.UtcNow < startTime.AddMinutes(2)) { Assert.False(lambdaTask.IsFaulted, "Lambda function failed: " + lambdaTask.Exception?.ToString()); await Task.Delay(500); diff --git a/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.Tests.Common/Helpers/TestHelpers.cs b/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.Tests.Common/Helpers/TestHelpers.cs index 79c618d48..9def24fe2 100644 --- a/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.Tests.Common/Helpers/TestHelpers.cs +++ b/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.Tests.Common/Helpers/TestHelpers.cs @@ -1,6 +1,9 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 +using System.Net; +using System.Net.Sockets; + namespace Amazon.Lambda.TestTool.Tests.Common.Helpers; public static class TestHelpers @@ -39,16 +42,22 @@ public static async Task SendRequest(string url) } } - private static int _maxLambdaRuntimePort = 6000; - private static int _maxApiGatewayPort = 9000; + private static int GetFreePort() + { + var listener = new TcpListener(IPAddress.Loopback, 0); + listener.Start(); + var port = ((IPEndPoint)listener.LocalEndpoint).Port; + listener.Stop(); + return port; + } public static int GetNextLambdaRuntimePort() { - return Interlocked.Increment(ref _maxLambdaRuntimePort); + return GetFreePort(); } public static int GetNextApiGatewayPort() { - return Interlocked.Increment(ref _maxApiGatewayPort); + return GetFreePort(); } } diff --git a/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.UnitTests/RuntimeApiTests.cs b/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.UnitTests/RuntimeApiTests.cs index 998460411..cb2afa01a 100644 --- a/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.UnitTests/RuntimeApiTests.cs +++ b/Tools/LambdaTestTool-v2/tests/Amazon.Lambda.TestTool.UnitTests/RuntimeApiTests.cs @@ -39,6 +39,8 @@ public async Task AddEventToDataStore() var testToolProcess = TestToolProcess.Startup(options, cancellationTokenSource.Token); try { + Assert.True(await TestHelpers.WaitForApiToStartAsync($"{testToolProcess.ServiceUrl}/lambda-runtime-api/healthcheck")); + var lambdaClient = ConstructLambdaServiceClient(testToolProcess.ServiceUrl); var invokeFunction = new InvokeRequest { @@ -92,6 +94,8 @@ public async Task InvokeRequestResponse() var testToolProcess = TestToolProcess.Startup(options, cancellationTokenSource.Token); try { + Assert.True(await TestHelpers.WaitForApiToStartAsync($"{testToolProcess.ServiceUrl}/lambda-runtime-api/healthcheck")); + var handler = (string input, ILambdaContext context) => { Thread.Sleep(1000); // Add a sleep to prove the LambdaRuntimeApi waited for the completion. diff --git a/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester.csproj b/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester.csproj index 878c49507..345380f68 100644 --- a/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester.csproj +++ b/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester.csproj @@ -6,7 +6,7 @@ Exe A tool to help debug and test your .NET Core AWS Lambda functions locally. Latest - 0.17.0 + 0.17.1 AWS .NET Lambda Test Tool Apache 2 AWS;Amazon;Lambda diff --git a/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester10_0-pack.csproj b/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester10_0-pack.csproj index f18756bce..55fbe0b02 100644 --- a/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester10_0-pack.csproj +++ b/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester10_0-pack.csproj @@ -5,7 +5,7 @@ Exe A tool to help debug and test your .NET 10.0 AWS Lambda functions locally. - 0.17.0 + 0.17.1 AWS .NET Lambda Test Tool Apache 2 AWS;Amazon;Lambda diff --git a/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester80-pack.csproj b/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester80-pack.csproj index 45afdfeb1..422742f07 100644 --- a/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester80-pack.csproj +++ b/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester80-pack.csproj @@ -5,7 +5,7 @@ Exe A tool to help debug and test your .NET 8.0 AWS Lambda functions locally. - 0.17.0 + 0.17.1 AWS .NET Lambda Test Tool Apache 2 AWS;Amazon;Lambda diff --git a/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester90-pack.csproj b/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester90-pack.csproj index e4a7cc450..d656e40be 100644 --- a/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester90-pack.csproj +++ b/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool.BlazorTester/Amazon.Lambda.TestTool.BlazorTester90-pack.csproj @@ -5,7 +5,7 @@ Exe A tool to help debug and test your .NET 9.0 AWS Lambda functions locally. - 0.17.0 + 0.17.1 AWS .NET Lambda Test Tool Apache 2 AWS;Amazon;Lambda diff --git a/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool/TestToolStartup.cs b/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool/TestToolStartup.cs index 9693d64e4..87353be34 100644 --- a/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool/TestToolStartup.cs +++ b/Tools/LambdaTestTool/src/Amazon.Lambda.TestTool/TestToolStartup.cs @@ -10,7 +10,6 @@ namespace Amazon.Lambda.TestTool { public class TestToolStartup { - private static bool shouldDisableLogs; public class RunConfiguration { @@ -37,7 +36,7 @@ public static void Startup(string productName, Action try { var commandOptions = CommandLineOptions.Parse(args); - shouldDisableLogs = Utils.ShouldDisableLogs(commandOptions); + var shouldDisableLogs = Utils.ShouldDisableLogs(commandOptions); if (!shouldDisableLogs) Utils.PrintToolTitle(productName); @@ -76,7 +75,7 @@ public static void Startup(string productName, Action if (commandOptions.NoUI) { - ExecuteWithNoUi(localLambdaOptions, commandOptions, lambdaAssemblyDirectory, runConfiguration); + ExecuteWithNoUi(localLambdaOptions, commandOptions, lambdaAssemblyDirectory, runConfiguration, shouldDisableLogs); } else { @@ -118,16 +117,16 @@ public static void Startup(string productName, Action } - public static void ExecuteWithNoUi(LocalLambdaOptions localLambdaOptions, CommandLineOptions commandOptions, string lambdaAssemblyDirectory, RunConfiguration runConfiguration) + public static void ExecuteWithNoUi(LocalLambdaOptions localLambdaOptions, CommandLineOptions commandOptions, string lambdaAssemblyDirectory, RunConfiguration runConfiguration, bool shouldDisableLogs) { if (!shouldDisableLogs) runConfiguration.OutputWriter.WriteLine("Executing Lambda function without web interface"); var lambdaProjectDirectory = Utils.FindLambdaProjectDirectory(lambdaAssemblyDirectory); - string configFile = DetermineConfigFile(commandOptions, lambdaAssemblyDirectory: lambdaAssemblyDirectory, lambdaProjectDirectory: lambdaProjectDirectory); - LambdaConfigInfo configInfo = LoadLambdaConfigInfo(configFile, commandOptions, lambdaAssemblyDirectory: lambdaAssemblyDirectory, lambdaProjectDirectory: lambdaProjectDirectory, runConfiguration); - LambdaFunction lambdaFunction = LoadLambdaFunction(configInfo, localLambdaOptions, commandOptions, lambdaAssemblyDirectory: lambdaAssemblyDirectory, lambdaProjectDirectory: lambdaProjectDirectory, runConfiguration); + string configFile = DetermineConfigFile(commandOptions, lambdaAssemblyDirectory: lambdaAssemblyDirectory, lambdaProjectDirectory: lambdaProjectDirectory, shouldDisableLogs: shouldDisableLogs); + LambdaConfigInfo configInfo = LoadLambdaConfigInfo(configFile, commandOptions, lambdaAssemblyDirectory: lambdaAssemblyDirectory, lambdaProjectDirectory: lambdaProjectDirectory, runConfiguration, shouldDisableLogs: shouldDisableLogs); + LambdaFunction lambdaFunction = LoadLambdaFunction(configInfo, localLambdaOptions, commandOptions, lambdaAssemblyDirectory: lambdaAssemblyDirectory, lambdaProjectDirectory: lambdaProjectDirectory, runConfiguration, shouldDisableLogs: shouldDisableLogs); - string payload = DeterminePayload(localLambdaOptions, commandOptions, lambdaAssemblyDirectory: lambdaAssemblyDirectory, lambdaProjectDirectory: lambdaProjectDirectory, runConfiguration); + string payload = DeterminePayload(localLambdaOptions, commandOptions, lambdaAssemblyDirectory: lambdaAssemblyDirectory, lambdaProjectDirectory: lambdaProjectDirectory, runConfiguration, shouldDisableLogs: shouldDisableLogs); var awsProfile = commandOptions.AWSProfile ?? configInfo.AWSProfile; if (!string.IsNullOrEmpty(awsProfile)) @@ -166,7 +165,7 @@ public static void ExecuteWithNoUi(LocalLambdaOptions localLambdaOptions, Comman Function = lambdaFunction }; - ExecuteRequest(request, localLambdaOptions, runConfiguration); + ExecuteRequest(request, localLambdaOptions, runConfiguration, shouldDisableLogs); if (runConfiguration.Mode == RunConfiguration.RunMode.Normal && commandOptions.PauseExit) @@ -176,7 +175,7 @@ public static void ExecuteWithNoUi(LocalLambdaOptions localLambdaOptions, Comman } } - private static string DetermineConfigFile(CommandLineOptions commandOptions, string lambdaAssemblyDirectory, string lambdaProjectDirectory) + private static string DetermineConfigFile(CommandLineOptions commandOptions, string lambdaAssemblyDirectory, string lambdaProjectDirectory, bool shouldDisableLogs) { string configFile = null; if (string.IsNullOrEmpty(commandOptions.ConfigFile)) @@ -199,7 +198,7 @@ private static string DetermineConfigFile(CommandLineOptions commandOptions, str return configFile; } - private static LambdaConfigInfo LoadLambdaConfigInfo(string configFile, CommandLineOptions commandOptions, string lambdaAssemblyDirectory, string lambdaProjectDirectory, RunConfiguration runConfiguration) + private static LambdaConfigInfo LoadLambdaConfigInfo(string configFile, CommandLineOptions commandOptions, string lambdaAssemblyDirectory, string lambdaProjectDirectory, RunConfiguration runConfiguration, bool shouldDisableLogs) { LambdaConfigInfo configInfo; if (configFile != null) @@ -226,7 +225,7 @@ private static LambdaConfigInfo LoadLambdaConfigInfo(string configFile, CommandL return configInfo; } - private static LambdaFunction LoadLambdaFunction(LambdaConfigInfo configInfo, LocalLambdaOptions localLambdaOptions, CommandLineOptions commandOptions, string lambdaAssemblyDirectory, string lambdaProjectDirectory, RunConfiguration runConfiguration) + private static LambdaFunction LoadLambdaFunction(LambdaConfigInfo configInfo, LocalLambdaOptions localLambdaOptions, CommandLineOptions commandOptions, string lambdaAssemblyDirectory, string lambdaProjectDirectory, RunConfiguration runConfiguration, bool shouldDisableLogs) { // If no function handler was explicitly set and there is only one function defined in the config file then assume the user wants to debug that function. var functionHandler = commandOptions.FunctionHandler; @@ -264,7 +263,7 @@ private static LambdaFunction LoadLambdaFunction(LambdaConfigInfo configInfo, Lo return lambdaFunction; } - private static string DeterminePayload(LocalLambdaOptions localLambdaOptions, CommandLineOptions commandOptions, string lambdaAssemblyDirectory, string lambdaProjectDirectory, RunConfiguration runConfiguration) + private static string DeterminePayload(LocalLambdaOptions localLambdaOptions, CommandLineOptions commandOptions, string lambdaAssemblyDirectory, string lambdaProjectDirectory, RunConfiguration runConfiguration, bool shouldDisableLogs) { var payload = commandOptions.Payload; @@ -346,7 +345,7 @@ private static string DeterminePayload(LocalLambdaOptions localLambdaOptions, Co return payload; } - private static void ExecuteRequest(ExecutionRequest request, LocalLambdaOptions localLambdaOptions, RunConfiguration runConfiguration) + private static void ExecuteRequest(ExecutionRequest request, LocalLambdaOptions localLambdaOptions, RunConfiguration runConfiguration, bool shouldDisableLogs) { try {