diff --git a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerPresignedUrlDownloadIntegrationTest.java b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerPresignedUrlDownloadIntegrationTest.java index 0fa70b78243..9dbe520dc75 100644 --- a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerPresignedUrlDownloadIntegrationTest.java +++ b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerPresignedUrlDownloadIntegrationTest.java @@ -41,6 +41,7 @@ import software.amazon.awssdk.testutils.RandomTempFile; import software.amazon.awssdk.transfer.s3.model.CompletedDownload; import software.amazon.awssdk.transfer.s3.model.CompletedFileDownload; +import software.amazon.awssdk.transfer.s3.model.Download; import software.amazon.awssdk.transfer.s3.model.FileDownload; import software.amazon.awssdk.transfer.s3.model.PresignedDownloadFileRequest; import software.amazon.awssdk.transfer.s3.model.PresignedDownloadRequest; @@ -109,6 +110,75 @@ void downloadWithPresignedUrl_toBytes_shouldReturnCorrectData(S3TransferManager assertThat(completed.result().asByteArray()).hasSize(objSize); } + static Stream progressTestCases() { + return Stream.of( + Arguments.of("multipart", tmJava, LARGE_KEY, null, LARGE_OBJ_SIZE), + Arguments.of("multipart", tmJava, LARGE_KEY, "bytes=0-1048575", 1048576), + Arguments.of("nonMultipart", tmNonMultipartJava, LARGE_KEY, null, LARGE_OBJ_SIZE), + Arguments.of("nonMultipart", tmNonMultipartJava, LARGE_KEY, "bytes=0-1048575", 1048576) + ); + } + + @ParameterizedTest(name = "downloadFileWithPresignedUrl_progress_{0}_range={3}") + @MethodSource("progressTestCases") + void downloadFileWithPresignedUrl_progressTracking(String tmType, S3TransferManager tm, String key, + String range, int expectedSize) throws Exception { + Path downloadPath = RandomTempFile.randomUncreatedFile().toPath(); + + PresignedUrlDownloadRequest.Builder requestBuilder = PresignedUrlDownloadRequest.builder() + .presignedUrl(createPresignedRequest(key).url()); + if (range != null) { + requestBuilder.range(range); + } + + FileDownload download = tm.downloadFileWithPresignedUrl( + PresignedDownloadFileRequest.builder() + .presignedUrlDownloadRequest(requestBuilder.build()) + .destination(downloadPath) + .addTransferListener(LoggingTransferListener.create()) + .build()); + + download.completionFuture().join(); + + // Verify progress tracking worked - totalBytes is set correctly + assertThat(download.progress().snapshot().totalBytes()).isPresent(); + assertThat(download.progress().snapshot().totalBytes().getAsLong()).isEqualTo(expectedSize); + + // Verify transferredBytes reached expectedSize + assertThat(download.progress().snapshot().transferredBytes()).isEqualTo(expectedSize); + + // Verify file size matches expected + assertThat(downloadPath.toFile().length()).isEqualTo(expectedSize); + } + + @ParameterizedTest(name = "downloadWithPresignedUrl_toBytes_progress_{0}_range={3}") + @MethodSource("progressTestCases") + void downloadWithPresignedUrl_toBytes_progressTracking(String tmType, S3TransferManager tm, String key, + String range, int expectedSize) throws Exception { + + PresignedUrlDownloadRequest.Builder requestBuilder = PresignedUrlDownloadRequest.builder() + .presignedUrl(createPresignedRequest(key).url()); + if (range != null) { + requestBuilder.range(range); + } + + Download> download = tm.downloadWithPresignedUrl( + PresignedDownloadRequest.>builder() + .presignedUrlDownloadRequest(requestBuilder.build()) + .responseTransformer(AsyncResponseTransformer.toBytes()) + .addTransferListener(LoggingTransferListener.create()) + .build()); + + CompletedDownload> completed = download.completionFuture().join(); + + assertThat(download.progress().snapshot().totalBytes()).isPresent(); + assertThat(download.progress().snapshot().totalBytes().getAsLong()).isEqualTo(expectedSize); + + assertThat(download.progress().snapshot().transferredBytes()).isEqualTo(expectedSize); + + assertThat(completed.result().asByteArray()).hasSize(expectedSize); + } + private static PresignedDownloadFileRequest createFileDownloadRequest(String key, Path destination) { return PresignedDownloadFileRequest.builder() .presignedUrlDownloadRequest(PresignedUrlDownloadRequest.builder() diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java index 8d8825c67ea..8e3852a4f68 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java @@ -613,8 +613,9 @@ public final FileDownload downloadFileWithPresignedUrl(PresignedDownloadFileRequ progressUpdater.transferInitiated(); responseTransformer = isS3ClientMultipartEnabled() + && presignedDownloadFileRequest.presignedUrlDownloadRequest().range() == null ? progressUpdater.wrapForNonSerialFileDownload( - responseTransformer, GetObjectRequest.builder().build()) + responseTransformer, GetObjectRequest.builder().build()) : progressUpdater.wrapResponseTransformer(responseTransformer); progressUpdater.registerCompletion(returnFuture); @@ -652,8 +653,9 @@ public final Download downloadWithPresignedUrl( progressUpdater.transferInitiated(); responseTransformer = isS3ClientMultipartEnabled() + && presignedDownloadRequest.presignedUrlDownloadRequest().range() == null ? progressUpdater.wrapForNonSerialFileDownload( - responseTransformer, GetObjectRequest.builder().build()) + responseTransformer, GetObjectRequest.builder().build()) : progressUpdater.wrapResponseTransformer(responseTransformer); progressUpdater.registerCompletion(returnFuture); diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerPresignedUrlListenerWiremockTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerPresignedUrlListenerWiremockTest.java new file mode 100644 index 00000000000..8a1ee54582b --- /dev/null +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerPresignedUrlListenerWiremockTest.java @@ -0,0 +1,253 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.transfer.s3.internal; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.times; + +import com.github.tomakehurst.wiremock.client.WireMock; +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import java.io.IOException; +import java.net.URI; +import java.net.URL; +import java.util.concurrent.CompletionException; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentMatchers; +import org.mockito.Mockito; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.presignedurl.model.PresignedUrlDownloadRequest; +import software.amazon.awssdk.testutils.RandomTempFile; +import software.amazon.awssdk.transfer.s3.S3TransferManager; +import software.amazon.awssdk.transfer.s3.model.Download; +import software.amazon.awssdk.transfer.s3.model.FileDownload; +import software.amazon.awssdk.transfer.s3.model.PresignedDownloadFileRequest; +import software.amazon.awssdk.transfer.s3.model.PresignedDownloadRequest; +import software.amazon.awssdk.transfer.s3.progress.TransferListener; + +/** + * Tests that TransferListener callbacks fire correctly for presigned URL downloads + * with both multipart-enabled and non-multipart clients. + */ +@WireMockTest +public class S3TransferManagerPresignedUrlListenerWiremockTest { + + private static URI testEndpoint; + private static RandomTempFile testFile; + + @BeforeAll + public static void init(WireMockRuntimeInfo wm) throws IOException { + testEndpoint = URI.create(wm.getHttpBaseUrl()); + testFile = new RandomTempFile("presigned-listener-test", 1024); + } + + @BeforeEach + void resetWireMock() { + WireMock.reset(); + } + + private static S3AsyncClient s3AsyncClient(boolean multipartEnabled) { + return S3AsyncClient.builder() + .multipartEnabled(multipartEnabled) + .region(Region.US_EAST_1) + .endpointOverride(testEndpoint) + .credentialsProvider( + StaticCredentialsProvider.create(AwsBasicCredentials.create("key", "secret"))) + .build(); + } + + static Stream presignedUrlTestCases() { + return Stream.of( + Arguments.of(true, "toFile", null), + Arguments.of(true, "toFile", "bytes=0-511"), + Arguments.of(true, "toBytes", null), + Arguments.of(true, "toBytes", "bytes=0-511"), + Arguments.of(false, "toFile", null), + Arguments.of(false, "toFile", "bytes=0-511"), + Arguments.of(false, "toBytes", null), + Arguments.of(false, "toBytes", "bytes=0-511") + ); + } + + @ParameterizedTest(name = "presignedUrlDownload_multipart={0}_type={1}_range={2}") + @MethodSource("presignedUrlTestCases") + void presignedUrlDownload_shouldInvokeListener(boolean multipartEnabled, String type, String range) throws Exception { + S3AsyncClient s3Async = s3AsyncClient(multipartEnabled); + S3TransferManager tm = new GenericS3TransferManager(s3Async, mock(UploadDirectoryHelper.class), + mock(TransferManagerConfiguration.class), + mock(DownloadDirectoryHelper.class)); + + byte[] responseBody = new byte[512]; + stubFor(get(urlPathEqualTo("/presigned-key")).willReturn(aResponse() + .withStatus(206) + .withHeader("Content-Length", "512") + .withHeader("Content-Range", "bytes 0-511/512") + .withHeader("ETag", "\"test-etag\"") + .withBody(responseBody))); + + TransferListener listener = mock(TransferListener.class); + URL presignedUrl = new URL(testEndpoint + "/presigned-key?X-Amz-Algorithm=AWS4-HMAC-SHA256"); + + PresignedUrlDownloadRequest.Builder requestBuilder = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl); + if (range != null) { + requestBuilder.range(range); + } + + if ("toFile".equals(type)) { + FileDownload download = tm.downloadFileWithPresignedUrl( + PresignedDownloadFileRequest.builder() + .presignedUrlDownloadRequest(requestBuilder.build()) + .destination(testFile.toPath()) + .addTransferListener(listener) + .build()); + download.completionFuture().join(); + } else { + Download download = tm.downloadWithPresignedUrl( + PresignedDownloadRequest.builder() + .presignedUrlDownloadRequest(requestBuilder.build()) + .responseTransformer(AsyncResponseTransformer.toBytes()) + .addTransferListener(listener) + .build()); + download.completionFuture().join(); + } + + Mockito.verify(listener, timeout(1000).times(1)).transferInitiated(ArgumentMatchers.any()); + Mockito.verify(listener, timeout(1000).atLeastOnce()).bytesTransferred(ArgumentMatchers.any()); + + tm.close(); + s3Async.close(); + } + + static Stream presignedUrlFailureTestCases() { + return Stream.of( + Arguments.of(true, "toFile"), + Arguments.of(true, "toBytes"), + Arguments.of(false, "toFile"), + Arguments.of(false, "toBytes") + ); + } + + @ParameterizedTest(name = "presignedUrlDownload_failure_multipart={0}_type={1}") + @MethodSource("presignedUrlFailureTestCases") + void presignedUrlDownload_failure_shouldInvokeListener(boolean multipartEnabled, String type) throws Exception { + S3AsyncClient s3Async = s3AsyncClient(multipartEnabled); + S3TransferManager tm = new GenericS3TransferManager(s3Async, mock(UploadDirectoryHelper.class), + mock(TransferManagerConfiguration.class), + mock(DownloadDirectoryHelper.class)); + + stubFor(get(urlPathEqualTo("/presigned-key")) + .willReturn(aResponse().withStatus(404) + .withBody("TestErrorTest failure"))); + + TransferListener listener = mock(TransferListener.class); + URL presignedUrl = new URL(testEndpoint + "/presigned-key?X-Amz-Algorithm=AWS4-HMAC-SHA256"); + + if ("toFile".equals(type)) { + FileDownload download = tm.downloadFileWithPresignedUrl( + PresignedDownloadFileRequest.builder() + .presignedUrlDownloadRequest(PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build()) + .destination(testFile.toPath()) + .addTransferListener(listener) + .build()); + assertThatExceptionOfType(CompletionException.class).isThrownBy(() -> download.completionFuture().join()); + } else { + Download download = tm.downloadWithPresignedUrl( + PresignedDownloadRequest.builder() + .presignedUrlDownloadRequest(PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build()) + .responseTransformer(AsyncResponseTransformer.toBytes()) + .addTransferListener(listener) + .build()); + assertThatExceptionOfType(CompletionException.class).isThrownBy(() -> download.completionFuture().join()); + } + + Mockito.verify(listener, timeout(1000).times(1)).transferInitiated(ArgumentMatchers.any()); + Mockito.verify(listener, timeout(1000).times(1)).transferFailed(ArgumentMatchers.any()); + Mockito.verify(listener, times(0)).transferComplete(ArgumentMatchers.any()); + + tm.close(); + s3Async.close(); + } + + @ParameterizedTest(name = "presignedUrlDownload_cancelled_multipart={0}_type={1}") + @MethodSource("presignedUrlFailureTestCases") + void presignedUrlDownload_cancelled_shouldInvokeTransferFailed(boolean multipartEnabled, String type) throws Exception { + S3AsyncClient s3Async = s3AsyncClient(multipartEnabled); + S3TransferManager tm = new GenericS3TransferManager(s3Async, mock(UploadDirectoryHelper.class), + mock(TransferManagerConfiguration.class), + mock(DownloadDirectoryHelper.class)); + + // Slow response to keep request in-flight during cancellation + stubFor(get(urlPathEqualTo("/presigned-key")).willReturn(aResponse() + .withStatus(206) + .withHeader("Content-Length", "512") + .withHeader("Content-Range", "bytes 0-511/512") + .withHeader("ETag", "\"test-etag\"") + .withBody(new byte[512]) + .withFixedDelay(5000))); + + TransferListener listener = mock(TransferListener.class); + URL presignedUrl = new URL(testEndpoint + "/presigned-key?X-Amz-Algorithm=AWS4-HMAC-SHA256"); + + if ("toFile".equals(type)) { + FileDownload download = tm.downloadFileWithPresignedUrl( + PresignedDownloadFileRequest.builder() + .presignedUrlDownloadRequest(PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build()) + .destination(testFile.toPath()) + .addTransferListener(listener) + .build()); + download.completionFuture().cancel(true); + } else { + Download download = tm.downloadWithPresignedUrl( + PresignedDownloadRequest.builder() + .presignedUrlDownloadRequest(PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build()) + .responseTransformer(AsyncResponseTransformer.toBytes()) + .addTransferListener(listener) + .build()); + download.completionFuture().cancel(true); + } + + Mockito.verify(listener, timeout(1000).times(1)).transferInitiated(ArgumentMatchers.any()); + Mockito.verify(listener, timeout(1000).times(1)).transferFailed(ArgumentMatchers.any()); + Mockito.verify(listener, times(0)).transferComplete(ArgumentMatchers.any()); + + tm.close(); + s3Async.close(); + } +}