Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions extensions/aws/s3/MinifiToAwsInputStream.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/**
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "MinifiToAwsInputStream.h"

#include <algorithm>
#include <span>

namespace org::apache::nifi::minifi::aws::s3 {

// Supplies the next readable character without consuming it.
// When the get area is exhausted, it is refilled from the underlying stream with at most
// DEFAULT_BUFFER_SIZE bytes, never reading past start_pos_ + content_length_.
// Returns traits_type::eof() when the window is exhausted, when the stream yields no bytes,
// or when the read fails; in the failure case badbit is also set on the owning stream.
MinifiInputStreamBuf::int_type MinifiInputStreamBuf::underflow() {
  if (gptr() >= egptr()) {
    const uint64_t window_end = start_pos_ + content_length_;
    const uint64_t current_pos = stream_->tell();
    if (current_pos >= window_end) {
      return traits_type::eof();
    }

    // Clamp the refill so that it never crosses the end of the window.
    const auto chunk_size = std::min<uint64_t>(utils::configuration::DEFAULT_BUFFER_SIZE, window_end - current_pos);
    char* const buffer_begin = buffer_.data();
    const auto read_result = stream_->read(std::span(reinterpret_cast<std::byte*>(buffer_begin), gsl::narrow<size_t>(chunk_size)));
    if (io::isError(read_result)) {
      owner_->setstate(std::ios_base::badbit);
      return traits_type::eof();
    }
    if (read_result == 0) {
      return traits_type::eof();
    }
    setg(buffer_begin, buffer_begin, buffer_begin + read_result);
  }
  return traits_type::to_int_type(*gptr());
}

// Translates a relative seek into a virtual offset (0 == start_pos_ in the underlying stream)
// and delegates the actual repositioning to seekpos().
// Returns pos_type(off_type(-1)) when `which` does not include std::ios_base::in.
MinifiInputStreamBuf::pos_type MinifiInputStreamBuf::seekoff(off_type off, std::ios_base::seekdir way, std::ios_base::openmode which) {
  if (!(which & std::ios_base::in)) {
    return {off_type(-1)};
  }

  off_type target_offset{};
  switch (way) {
    case std::ios_base::beg:
      target_offset = off;
      break;
    case std::ios_base::cur: {
      // The underlying stream is ahead of the logical read position by the bytes that are
      // buffered in the get area but not yet consumed.
      const auto logical_stream_pos = static_cast<off_type>(stream_->tell()) - static_cast<off_type>(egptr() - gptr());
      target_offset = logical_stream_pos - static_cast<off_type>(start_pos_) + off;
      break;
    }
    default:
      // Every other direction is treated as relative to the end of the window.
      target_offset = static_cast<off_type>(content_length_) + off;
      break;
  }
  return seekpos(pos_type(target_offset), which);
}

// Moves the read position to the virtual offset `pos`, where offset 0 corresponds to start_pos_
// in the underlying stream, and discards any buffered data.
// Returns `pos` on success. Returns pos_type(off_type(-1)) when `which` does not include
// std::ios_base::in, or when `pos` lies outside [0, content_length_]; in both failure cases
// the underlying stream and the get area are left untouched.
MinifiInputStreamBuf::pos_type MinifiInputStreamBuf::seekpos(pos_type pos, std::ios_base::openmode which) {
  if (!(which & std::ios_base::in)) {
    return {off_type(-1)};
  }

  const auto virtual_offset = off_type(pos);
  // A negative offset would wrap around to a huge value in the unsigned addition below, and an
  // offset past content_length_ would move the underlying stream outside of this buffer's window.
  // Seeking exactly to content_length_ (the end of the window) is valid.
  if (virtual_offset < 0 || static_cast<uint64_t>(virtual_offset) > content_length_) {
    return {off_type(-1)};
  }

  stream_->seek(start_pos_ + static_cast<size_t>(virtual_offset));
  setg(buffer_.data(), buffer_.data(), buffer_.data());  // invalidate read buffer
  return pos;
}
Comment on lines +63 to +70
Copy link
Copy Markdown
Contributor

@fgerlits fgerlits May 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check that 0 <= pos <= content_length_? (or should that be < content_length_?)


} // namespace org::apache::nifi::minifi::aws::s3
60 changes: 60 additions & 0 deletions extensions/aws/s3/MinifiToAwsInputStream.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/**
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <streambuf>
#include <iostream>
#include <memory>
#include <vector>

#include "utils/ConfigurationUtils.h"
#include "minifi-cpp/io/InputStream.h"
#include "minifi-cpp/utils/gsl.h"

namespace org::apache::nifi::minifi::aws::s3 {

// std::streambuf that reads from an io::InputStream through an internal buffer of
// utils::configuration::DEFAULT_BUFFER_SIZE chars.
// The readable window starts at the position the stream reports at construction time and is
// content_length bytes long; seek positions are interpreted relative to that start position.
class MinifiInputStreamBuf : public std::streambuf {
public:
// stream: source of the data; its current position becomes the start of the window.
// content_length: number of bytes that may be read starting from that position.
// owner: the std::basic_ios on which badbit is set when reading from `stream` fails.
MinifiInputStreamBuf(std::shared_ptr<io::InputStream> stream, uint64_t content_length, gsl::not_null<std::basic_ios<char>*> owner)
: stream_(std::move(stream)),
start_pos_(stream_->tell()),
content_length_(content_length),
buffer_(utils::configuration::DEFAULT_BUFFER_SIZE),
owner_(owner) {}

protected:
// Refills the get area from the underlying stream when it is exhausted.
int_type underflow() override;
// Relative seek; computes a virtual offset and forwards it to seekpos().
pos_type seekoff(off_type off, std::ios_base::seekdir way, std::ios_base::openmode which) override;
// Absolute seek to a virtual offset; empties the get area.
pos_type seekpos(pos_type pos, std::ios_base::openmode which) override;

private:
// NOTE: stream_ must stay declared before start_pos_, because the initializer of start_pos_
// reads stream_ after the constructor argument has been moved into it.
std::shared_ptr<io::InputStream> stream_;
// Position of stream_ at construction time; virtual offset 0 maps to this position.
uint64_t start_pos_;
// Length of the readable window in bytes.
uint64_t content_length_;
// Backing storage of the get area.
std::vector<char> buffer_;
// Stream object that receives badbit on read errors.
gsl::not_null<std::basic_ios<char>*> owner_;
};

// std::basic_iostream<char> whose stream buffer is a MinifiInputStreamBuf reading from the
// given io::InputStream.
// The buffer is a private base listed before std::basic_iostream<char>, so it is initialized
// before its address is handed to the std::basic_iostream<char> constructor.
class MinifiToAwsInputStream : private MinifiInputStreamBuf, public std::basic_iostream<char> {
public:
// stream: source of the data, read starting from its current position.
// content_length: number of bytes exposed through this stream.
// The object passes itself (as std::basic_ios<char>*) to the buffer as the owner on which
// read errors are reported.
MinifiToAwsInputStream(std::shared_ptr<io::InputStream> stream, uint64_t content_length)
: MinifiInputStreamBuf(std::move(stream), content_length, gsl::not_null<std::basic_ios<char>*>(static_cast<std::basic_ios<char>*>(this))),
std::basic_iostream<char>(static_cast<MinifiInputStreamBuf*>(this)) {}
};

} // namespace org::apache::nifi::minifi::aws::s3
42 changes: 12 additions & 30 deletions extensions/aws/s3/S3Wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <vector>
#include <algorithm>

#include "MinifiToAwsInputStream.h"
#include "S3ClientRequestSender.h"
#include "utils/ArrayUtils.h"
#include "utils/StringUtils.h"
Expand Down Expand Up @@ -68,32 +69,11 @@ std::string S3Wrapper::getEncryptionString(Aws::S3Crt::Model::ServerSideEncrypti
return "";
}

// Copies up to read_limit bytes from `stream` into a newly created Aws::StringStream, reading in
// chunks of at most BUFFER_SIZE bytes, and stops early when the stream yields no more data.
// On return, read_size_out holds the number of bytes actually copied.
// Throws StreamReadException when a read reports an error; read_size_out is not modified then.
std::shared_ptr<Aws::StringStream> S3Wrapper::readFlowFileStream(const std::shared_ptr<io::InputStream>& stream, uint64_t read_limit, uint64_t& read_size_out) {
  std::array<std::byte, BUFFER_SIZE> chunk{};
  auto collected = std::make_shared<Aws::StringStream>();
  uint64_t total_read = 0;
  for (;;) {
    if (total_read >= read_limit) {
      break;
    }
    const auto chunk_limit = (std::min)(read_limit - total_read, uint64_t{BUFFER_SIZE});
    const auto outcome = stream->read(std::span(chunk).subspan(0, chunk_limit));
    if (io::isError(outcome)) {
      throw StreamReadException("Reading flow file inputstream failed!");
    }
    if (!(outcome > 0)) {
      break;  // no more data available before reaching read_limit
    }
    collected->write(reinterpret_cast<char*>(chunk.data()), gsl::narrow<std::streamsize>(outcome));
    total_read += outcome;
  }
  read_size_out = total_read;
  return collected;
}

std::optional<PutObjectResult> S3Wrapper::putObject(const PutObjectRequestParameters& put_object_params, const std::shared_ptr<io::InputStream>& stream, uint64_t flow_size) {
uint64_t read_size{};
auto data_stream = readFlowFileStream(stream, flow_size, read_size);
auto request = createPutObjectRequest<Aws::S3Crt::Model::PutObjectRequest>(put_object_params);
request.SetBody(data_stream);
auto aws_stream = std::make_shared<MinifiToAwsInputStream>(stream, flow_size);
request.SetBody(aws_stream);
request.SetContentLength(static_cast<long long>(flow_size)); // NOLINT(runtime/int,google-runtime-int) AWS SDK expects long long for content length
Comment thread
lordgamez marked this conversation as resolved.

auto aws_result = request_sender_->sendPutObjectRequest(request);
if (!aws_result) {
Expand All @@ -120,32 +100,34 @@ std::optional<S3Wrapper::UploadPartsResult> S3Wrapper::uploadParts(const PutObje
const size_t start_part = upload_state.uploaded_parts + 1;
const size_t last_part = start_part + part_count - 1;
for (size_t part_number = start_part; part_number <= last_part; ++part_number) {
uint64_t read_size{};
const auto remaining = flow_size - total_read;
const auto next_read_size = std::min(remaining, upload_state.part_size);
auto stream_ptr = readFlowFileStream(stream, next_read_size, read_size);
total_read += read_size;
auto aws_stream = std::make_shared<MinifiToAwsInputStream>(stream, next_read_size);

auto upload_part_request = Aws::S3Crt::Model::UploadPartRequest{}
.WithBucket(put_object_params.bucket)
.WithKey(put_object_params.object_key)
.WithPartNumber(gsl::narrow<int>(part_number))
.WithUploadId(upload_state.upload_id)
.WithChecksumAlgorithm(put_object_params.checksum_algorithm);
upload_part_request.SetBody(stream_ptr);
upload_part_request.SetBody(aws_stream);
upload_part_request.SetContentLength(static_cast<long long>(next_read_size)); // NOLINT(runtime/int,google-runtime-int) AWS SDK expects long long for content length

Aws::Utils::ByteBuffer part_md5(Aws::Utils::HashingUtils::CalculateMD5(*stream_ptr));
Aws::Utils::ByteBuffer part_md5(Aws::Utils::HashingUtils::CalculateMD5(*aws_stream));
upload_part_request.SetContentMD5(Aws::Utils::HashingUtils::Base64Encode(part_md5));
// Reset to part start so the SDK reads the full part during the upload request.
aws_stream->seekg(0, std::ios::beg);

auto upload_part_result = request_sender_->sendUploadPartRequest(upload_part_request);
if (!upload_part_result) {
logger_->log_error("Failed to upload part {} of {} of S3 object with key '{}'", part_number, last_part, put_object_params.object_key);
return std::nullopt;
}
total_read += next_read_size;
result.part_etags.push_back(upload_part_result->GetETag());
upload_state.uploaded_etags.push_back(upload_part_result->GetETag());
upload_state.uploaded_parts += 1;
upload_state.uploaded_size += read_size;
upload_state.uploaded_size += next_read_size;
multipart_upload_storage_->storeState(put_object_params.bucket, put_object_params.object_key, upload_state);
logger_->log_info("Uploaded part {} of {} S3 object with key '{}'", part_number, last_part, put_object_params.object_key);
}
Expand Down
2 changes: 0 additions & 2 deletions extensions/aws/s3/S3Wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include <map>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
Expand Down Expand Up @@ -301,7 +300,6 @@ class S3Wrapper {

static int64_t writeFetchedBody(Aws::IOStream& source, int64_t data_size, io::OutputStream& output);
static std::string getEncryptionString(Aws::S3Crt::Model::ServerSideEncryption encryption);
static std::shared_ptr<Aws::StringStream> readFlowFileStream(const std::shared_ptr<io::InputStream>& stream, uint64_t read_limit, uint64_t& read_size_out);

std::optional<std::vector<ListedObjectAttributes>> listVersions(const ListRequestParameters& params);
std::optional<std::vector<ListedObjectAttributes>> listObjects(const ListRequestParameters& params);
Expand Down
9 changes: 9 additions & 0 deletions extensions/aws/tests/MockS3RequestSender.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#pragma once

#include <limits>
#include <map>
#include <optional>
#include <string>
Expand Down Expand Up @@ -232,6 +233,10 @@ class MockS3RequestSender : public minifi::aws::s3::S3RequestSender {
fail_on_part_ = 0;
return std::nullopt;
}
// Consume the body like the real SDK, allowing the next part to start at the correct position
if (auto body = request.GetBody()) {
body->ignore(std::numeric_limits<std::streamsize>::max());
}
upload_part_requests.push_back(request);
Aws::S3Crt::Model::UploadPartResult result;
result.SetETag("etag" + std::to_string(etag_counter_));
Expand Down Expand Up @@ -294,6 +299,10 @@ class MockS3RequestSender : public minifi::aws::s3::S3RequestSender {
}

// Returns the complete body of the given upload part request as a string.
static std::string getUploadPartRequestBody(const Aws::S3Crt::Model::UploadPartRequest& upload_part_request) {
  const auto& body = upload_part_request.GetBody();
  // All parts share one underlying io::InputStream, which may have been moved elsewhere by the
  // time this helper runs, so rewind to the beginning of this part's window before reading.
  body->seekg(0);
  const std::istreambuf_iterator<char> body_begin(*body);
  const std::istreambuf_iterator<char> body_end;
  return std::string(body_begin, body_end);
}
Expand Down
Loading