diff --git a/cmake/sdksCommon.cmake b/cmake/sdksCommon.cmake index d6e18bfa7e6..1fb754a66b1 100644 --- a/cmake/sdksCommon.cmake +++ b/cmake/sdksCommon.cmake @@ -112,7 +112,7 @@ list(APPEND SDK_TEST_PROJECT_LIST "s3control:tests/aws-cpp-sdk-s3control-integra list(APPEND SDK_TEST_PROJECT_LIST "sns:tests/aws-cpp-sdk-sns-integration-tests") list(APPEND SDK_TEST_PROJECT_LIST "sqs:tests/aws-cpp-sdk-sqs-integration-tests") list(APPEND SDK_TEST_PROJECT_LIST "sqs:tests/aws-cpp-sdk-sqs-unit-tests") -list(APPEND SDK_TEST_PROJECT_LIST "transfer:tests/aws-cpp-sdk-transfer-tests") +list(APPEND SDK_TEST_PROJECT_LIST "transfer:tests/aws-cpp-sdk-transfer-tests,tests/aws-cpp-sdk-transfer-unit-tests") list(APPEND SDK_TEST_PROJECT_LIST "text-to-speech:tests/aws-cpp-sdk-text-to-speech-tests,tests/aws-cpp-sdk-polly-sample") list(APPEND SDK_TEST_PROJECT_LIST "timestream-query:tests/aws-cpp-sdk-timestream-query-unit-tests") list(APPEND SDK_TEST_PROJECT_LIST "transcribestreaming:tests/aws-cpp-sdk-transcribestreaming-integ-tests") diff --git a/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp b/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp index 556a77b683a..4bd350d9605 100644 --- a/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp +++ b/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp @@ -825,6 +825,33 @@ namespace Aws return rangeStream.str(); } + static bool VerifyContentRange(const Aws::String& requestedRange, const Aws::String& responseContentRange) + { + if (requestedRange.empty() || responseContentRange.empty()) + { + return false; + } + + if (requestedRange.find("bytes=") != 0) + { + return false; + } + Aws::String requestRange = requestedRange.substr(6); + + if (responseContentRange.find("bytes ") != 0) + { + return false; + } + Aws::String responseRange = responseContentRange.substr(6); + size_t slashPos = responseRange.find('/'); + if (slashPos != Aws::String::npos) + { + responseRange = responseRange.substr(0, slashPos); + } + + return requestRange == responseRange; + } + void TransferManager::DoSinglePartDownload(const std::shared_ptr& handle) { auto queuedParts = handle->GetQueuedParts(); @@ -1091,7 +1118,6 @@ namespace Aws const std::shared_ptr& context) { AWS_UNREFERENCED_PARAM(client); - AWS_UNREFERENCED_PARAM(request); std::shared_ptr transferContext = std::const_pointer_cast(std::static_pointer_cast(context)); @@ -1108,33 +1134,57 @@ namespace Aws handle->SetError(outcome.GetError()); TriggerErrorCallback(handle, outcome.GetError()); } - else + else if (request.RangeHasBeenSet()) { - if(handle->ShouldContinue()) - { - Aws::IOStream* bufferStream = partState->GetDownloadPartStream(); - assert(bufferStream); - - Aws::String errMsg{handle->WritePartToDownloadStream(bufferStream, partState->GetRangeBegin())}; - if (errMsg.empty()) { - handle->ChangePartToCompleted(partState, outcome.GetResult().GetETag()); - } else { - Aws::Client::AWSError error(Aws::S3::S3Errors::INTERNAL_FAILURE, - "InternalFailure", errMsg, false); - AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId() - << "] Failed to download object in Bucket: [" - << handle->GetBucketName() << "] with Key: [" << handle->GetKey() - << "] " << errMsg); - handle->ChangePartToFailed(partState); - handle->SetError(error); - TriggerErrorCallback(handle, error); + const auto& requestedRange = request.GetRange(); + const auto& responseContentRange = outcome.GetResult().GetContentRange(); + + if (responseContentRange.empty() or !VerifyContentRange(requestedRange, responseContentRange)) { + Aws::Client::AWSError error(Aws::S3::S3Errors::INTERNAL_FAILURE, + "ContentRangeMismatch", + "ContentRange in response does not match requested range", + false); + AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId() + << "] ContentRange mismatch. Requested: [" << requestedRange + << "] Received: [" << responseContentRange << "]"); + handle->ChangePartToFailed(partState); + handle->SetError(error); + TriggerErrorCallback(handle, error); + handle->Cancel(); + + if(partState->GetDownloadBuffer()) + { + m_bufferManager.Release(partState->GetDownloadBuffer()); + partState->SetDownloadBuffer(nullptr); } - } - else - { + return; + } + + if(handle->ShouldContinue()) + { + Aws::IOStream* bufferStream = partState->GetDownloadPartStream(); + assert(bufferStream); + + Aws::String errMsg{handle->WritePartToDownloadStream(bufferStream, partState->GetRangeBegin())}; + if (errMsg.empty()) { + handle->ChangePartToCompleted(partState, outcome.GetResult().GetETag()); + } else { + Aws::Client::AWSError error(Aws::S3::S3Errors::INTERNAL_FAILURE, + "InternalFailure", errMsg, false); + AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId() + << "] Failed to download object in Bucket: [" + << handle->GetBucketName() << "] with Key: [" << handle->GetKey() + << "] " << errMsg); handle->ChangePartToFailed(partState); + handle->SetError(error); + TriggerErrorCallback(handle, error); } } + else + { + handle->ChangePartToFailed(partState); + } + } // buffer cleanup if(partState->GetDownloadBuffer()) diff --git a/tests/aws-cpp-sdk-transfer-tests/TransferTests.cpp b/tests/aws-cpp-sdk-transfer-tests/TransferTests.cpp index ab061815ca4..dc2e3a0cc38 100644 --- a/tests/aws-cpp-sdk-transfer-tests/TransferTests.cpp +++ b/tests/aws-cpp-sdk-transfer-tests/TransferTests.cpp @@ -2328,6 +2328,40 @@ TEST_P(TransferTests, TransferManager_TestRelativePrefix) } } +TEST_P(TransferTests, TransferManager_ContentRangeVerificationTest) +{ + const Aws::String RandomFileName = Aws::Utils::UUID::RandomUUID(); + Aws::String testFileName = MakeFilePath(RandomFileName.c_str()); + ScopedTestFile testFile(testFileName, MEDIUM_TEST_SIZE, testString); + + TransferManagerConfiguration transferManagerConfig(m_executor.get()); + transferManagerConfig.s3Client = m_s3Clients[GetParam()]; + auto transferManager = TransferManager::Create(transferManagerConfig); + + std::shared_ptr uploadPtr = transferManager->UploadFile(testFileName, GetTestBucketName(), RandomFileName, "text/plain", Aws::Map()); + uploadPtr->WaitUntilFinished(); + ASSERT_EQ(TransferStatus::COMPLETED, uploadPtr->GetStatus()); + ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), RandomFileName.c_str())); + + auto downloadFileName = MakeDownloadFileName(RandomFileName); + auto createStreamFn = [=](){ +#ifdef _MSC_VER + return Aws::New(ALLOCATION_TAG, Aws::Utils::StringUtils::ToWString(downloadFileName.c_str()).c_str(), std::ios_base::out | std::ios_base::in | std::ios_base::binary | std::ios_base::trunc); +#else + return Aws::New(ALLOCATION_TAG, downloadFileName.c_str(), std::ios_base::out | std::ios_base::in | std::ios_base::binary | std::ios_base::trunc); +#endif + }; + + uint64_t offset = 1024; + uint64_t partSize = 2048; + std::shared_ptr downloadPtr = transferManager->DownloadFile(GetTestBucketName(), RandomFileName, offset, partSize, createStreamFn); + + downloadPtr->WaitUntilFinished(); + ASSERT_EQ(TransferStatus::COMPLETED, downloadPtr->GetStatus()); + ASSERT_EQ(partSize, downloadPtr->GetBytesTotalSize()); + ASSERT_EQ(partSize, downloadPtr->GetBytesTransferred()); +} + INSTANTIATE_TEST_SUITE_P(Https, TransferTests, testing::Values(TestType::Https)); INSTANTIATE_TEST_SUITE_P(Http, TransferTests, testing::Values(TestType::Http)); diff --git a/tests/aws-cpp-sdk-transfer-unit-tests/CMakeLists.txt b/tests/aws-cpp-sdk-transfer-unit-tests/CMakeLists.txt new file mode 100644 index 00000000000..867404b1fcc --- /dev/null +++ b/tests/aws-cpp-sdk-transfer-unit-tests/CMakeLists.txt @@ -0,0 +1,31 @@ +add_project(aws-cpp-sdk-transfer-unit-tests + "Unit Tests for the Transfer Manager" + aws-cpp-sdk-transfer + aws-cpp-sdk-s3 + testing-resources + aws_test_main + aws-cpp-sdk-core) + +add_definitions(-DRESOURCES_DIR="${CMAKE_CURRENT_SOURCE_DIR}/resources") + +if(MSVC AND BUILD_SHARED_LIBS) + add_definitions(-DGTEST_LINKED_AS_SHARED_LIBRARY=1) +endif() + +enable_testing() + +if(PLATFORM_ANDROID AND BUILD_SHARED_LIBS) + add_library(${PROJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/TransferUnitTests.cpp) +else() + add_executable(${PROJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/TransferUnitTests.cpp) +endif() + +set_compiler_flags(${PROJECT_NAME}) +set_compiler_warnings(${PROJECT_NAME}) + +target_link_libraries(${PROJECT_NAME} ${PROJECT_LIBS}) + +if(MSVC AND BUILD_SHARED_LIBS) + set_target_properties(${PROJECT_NAME} PROPERTIES LINK_FLAGS "/DELAYLOAD:aws-cpp-sdk-transfer.dll /DELAYLOAD:aws-cpp-sdk-core.dll") + target_link_libraries(${PROJECT_NAME} delayimp.lib) +endif() diff --git a/tests/aws-cpp-sdk-transfer-unit-tests/TransferUnitTests.cpp b/tests/aws-cpp-sdk-transfer-unit-tests/TransferUnitTests.cpp new file mode 100644 index 00000000000..8446473eff7 --- /dev/null +++ b/tests/aws-cpp-sdk-transfer-unit-tests/TransferUnitTests.cpp @@ -0,0 +1,71 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace Aws; +using namespace Aws::S3; +using namespace Aws::S3::Model; +using namespace Aws::Transfer; +using namespace Aws::Utils::Threading; + +const char* ALLOCATION_TAG = "TransferUnitTest"; + +class MockS3Client : public S3Client { +public: + MockS3Client() : S3Client(){}; + + GetObjectOutcome GetObject(const GetObjectRequest&) const override { + GetObjectResult result; + // Return wrong range to trigger validation failure + result.SetContentRange("bytes 1024-2047/2048"); + auto stream = Aws::New(ALLOCATION_TAG); + *stream << "mock data"; + result.ReplaceBody(stream); + return GetObjectOutcome(std::move(result)); + } +}; + +class TransferUnitTest : public testing::Test { +protected: + void SetUp() override { + executor = Aws::MakeShared(ALLOCATION_TAG, 1); + mockS3Client = Aws::MakeShared(ALLOCATION_TAG); + } + + static void SetUpTestSuite() { + InitAPI(_options); + } + + static void TearDownTestSuite() { + ShutdownAPI(_options); + } + + std::shared_ptr executor; + std::shared_ptr mockS3Client; + static SDKOptions _options; +}; + +SDKOptions TransferUnitTest::_options; + +TEST_F(TransferUnitTest, ContentValidationShouldFail) { + TransferManagerConfiguration config(executor.get()); + config.s3Client = mockS3Client; + auto transferManager = TransferManager::Create(config); + + auto createStreamFn = []() { + return Aws::New(ALLOCATION_TAG); + }; + + // Request bytes 0-1023 but mock returns 1024-2047, should fail validation + auto handle = transferManager->DownloadFile("test-bucket", "test-key", 0, 1024, createStreamFn); + handle->WaitUntilFinished(); + + EXPECT_EQ(TransferStatus::FAILED, handle->GetStatus()); +}