
Commit 0b8180e

[TensorRT RTX EP] Be able to specify aux streams (#26569)
### Description
Allow auxiliary streams to be specified for the TensorRT RTX EP.

### Motivation and Context
In some use cases we want full control over all the streams used by TRT-RTX, including the auxiliary ones.
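Below is a hedged usage sketch (not part of this commit) of how an application might drive the new options added in this change. Only the option keys `user_aux_stream_array` and `nv_length_aux_stream_array` come from this commit; the EP registration name `"NvTensorRtRtx"`, the model path, and the `Ort::SessionOptions::AppendExecutionProvider` call pattern on the application side are assumptions.

```cpp
// Hypothetical sketch: passing caller-owned auxiliary streams to the
// TensorRT RTX EP via the new provider options. Error handling omitted.
#include <onnxruntime_cxx_api.h>
#include <cuda_runtime_api.h>

#include <array>
#include <string>
#include <unordered_map>

int main() {
  // The caller owns the auxiliary streams; the EP only stores the raw pointer.
  std::array<cudaStream_t, 2> aux_streams{};
  for (auto& s : aux_streams) {
    cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking);
  }

  // Addresses are passed as decimal strings, mirroring user_compute_stream.
  std::unordered_map<std::string, std::string> options{
      {"user_aux_stream_array", std::to_string(reinterpret_cast<size_t>(aux_streams.data()))},
      {"nv_length_aux_stream_array", std::to_string(aux_streams.size())},
  };

  Ort::Env env;
  Ort::SessionOptions session_options;
  session_options.AppendExecutionProvider("NvTensorRtRtx", options);  // provider name is an assumption

  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  // ... run inference; keep aux_streams alive until the session is destroyed ...

  for (auto& s : aux_streams) {
    cudaStreamDestroy(s);
  }
  return 0;
}
```

As with `user_compute_stream`, the stream-array address is serialized as an integer string; the EP stores the raw pointer, so the streams must outlive the session.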

5 files changed (+41, -0 lines)

include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h

Lines changed: 4 additions & 0 deletions
@@ -7,8 +7,10 @@
  * - `kDeviceId`: Specifies the GPU device ID to use.
  * - `kHasUserComputeStream`: Indicates whether a user-provided compute stream is used.
  * - `kUserComputeStream`: Specifies the user-provided compute stream.
+ * - `kUserAuxStreamArray`: Specifies the user-provided auxiliary stream array.
  * - `kMaxWorkspaceSize`: Sets the maximum workspace size for GPU memory allocation.
  * - 'kMaxSharedMemSize': Sets the maximum amount of shared memory that TensorRT kernels are allowed to use
+ * - `kLengthAuxStreamArray`: Specifies the length of the auxiliary stream array (kUserAuxStreamArray); also sets the maximum number of auxiliary streams for TensorRT execution.
  * - `kDumpSubgraphs`: Enables or disables dumping of subgraphs for debugging.
  * - `kDetailedBuildLog`: Enables or disables detailed build logs for debugging.
  * - `kProfilesMinShapes`: Specifies the minimum shapes for profiling.
@@ -24,8 +26,10 @@ namespace provider_option_names {
 constexpr const char* kDeviceId = "device_id";
 constexpr const char* kHasUserComputeStream = "has_user_compute_stream";
 constexpr const char* kUserComputeStream = "user_compute_stream";
+constexpr const char* kUserAuxStreamArray = "user_aux_stream_array";
 constexpr const char* kMaxWorkspaceSize = "nv_max_workspace_size";
 constexpr const char* kMaxSharedMemSize = "nv_max_shared_mem_size";
+constexpr const char* kLengthAuxStreamArray = "nv_length_aux_stream_array";
 constexpr const char* kDumpSubgraphs = "nv_dump_subgraphs";
 constexpr const char* kDetailedBuildLog = "nv_detailed_build_log";
 constexpr const char* kProfilesMinShapes = "nv_profile_min_shapes";

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Lines changed: 21 additions & 0 deletions
@@ -984,6 +984,17 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info)
     stream_ = nullptr;  // Will be created in compute function
   }
 
+  if (info.user_aux_stream_array != nullptr) {
+    if (info.auxiliary_streams <= 0) {
+      ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Auxiliary streams must be greater than 0 when using external auxiliary streams"));
+    }
+    external_aux_streams_ = true;
+    aux_streams_ = reinterpret_cast<cudaStream_t*>(info.user_aux_stream_array);
+  } else {
+    external_aux_streams_ = false;
+    aux_streams_ = nullptr;
+  }
+
   std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes;
 
   // incase the EP context is dumped the engine cache has to be enabled
@@ -3039,6 +3050,11 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
       return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP select an optimization profile for the current context failed");
     }
 
+    // Set auxiliary stream if provided by user
+    if (external_aux_streams_ && aux_streams_ != nullptr) {
+      trt_context->setAuxStreams(aux_streams_, (int32_t)auxiliary_streams_);
+    }
+
     // Check before using trt_engine
     if (trt_engine == nullptr) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "No engine is found.");
@@ -3450,6 +3466,11 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
       }
     }
 
+    // Set auxiliary stream if provided by user
+    if (external_aux_streams_ && aux_streams_ != nullptr) {
+      trt_context->setAuxStreams(aux_streams_, (int32_t)auxiliary_streams_);
+    }
+
     // Start CUDA graph capture with the correct stream
     // Note: We need to set the stream and start capture here because this is where we have access to the actual compute stream
     // Get the graph annotation ID that was stored during OnRunStart
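For context, a minimal sketch of the TensorRT-side call the EP is wiring up here, assuming the standard `nvinfer1::IExecutionContext` interface (`setAuxStreams` / `enqueueV3`); engine creation, I/O bindings, and error handling are omitted.

```cpp
// Sketch only: how caller-owned auxiliary streams reach TensorRT execution.
#include <NvInfer.h>
#include <cuda_runtime_api.h>

void enqueue_with_user_aux_streams(nvinfer1::IExecutionContext* trt_context,
                                   cudaStream_t compute_stream,
                                   cudaStream_t* aux_streams,
                                   int32_t num_aux_streams) {
  // Hand TensorRT the caller-owned auxiliary streams; without this call,
  // TensorRT manages its own internal auxiliary streams.
  trt_context->setAuxStreams(aux_streams, num_aux_streams);

  // Work may now be launched on both the main compute stream and the
  // auxiliary streams; TensorRT synchronizes them around the enqueue call.
  trt_context->enqueueV3(compute_stream);
}
```

Supplying the streams this way is what gives the caller the stream-level control described in the commit message.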

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h

Lines changed: 2 additions & 0 deletions
@@ -349,6 +349,8 @@ class NvExecutionProvider : public IExecutionProvider {
   mutable NvExecutionProviderInfo info_;
   bool external_stream_ = false;
   cudaStream_t stream_ = nullptr;
+  bool external_aux_streams_ = false;
+  cudaStream_t* aux_streams_ = nullptr;
   int max_partition_iterations_ = 1000;
   size_t min_subgraph_size_ = 1;
   size_t max_workspace_size_ = 0;

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc

Lines changed: 13 additions & 0 deletions
@@ -16,6 +16,7 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi
                                                                      const ConfigOptions& session_options) {
   NvExecutionProviderInfo info{};
   void* user_compute_stream = nullptr;
+  void* user_aux_stream_array = nullptr;
   void* onnx_bytestream = nullptr;
   void* external_data_bytestream = nullptr;
   ORT_THROW_IF_ERROR(
@@ -41,8 +42,17 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi
             user_compute_stream = reinterpret_cast<void*>(address);
             return Status::OK();
           })
+      .AddValueParser(
+          nv::provider_option_names::kUserAuxStreamArray,
+          [&user_aux_stream_array](const std::string& value_str) -> Status {
+            size_t address;
+            ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
+            user_aux_stream_array = reinterpret_cast<void*>(address);
+            return Status::OK();
+          })
       .AddAssignmentToReference(nv::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size)
       .AddAssignmentToReference(nv::provider_option_names::kMaxSharedMemSize, info.max_shared_mem_size)
+      .AddAssignmentToReference(nv::provider_option_names::kLengthAuxStreamArray, info.auxiliary_streams)
       .AddAssignmentToReference(nv::provider_option_names::kDumpSubgraphs, info.dump_subgraphs)
       .AddAssignmentToReference(nv::provider_option_names::kDetailedBuildLog, info.detailed_build_log)
       .AddAssignmentToReference(nv::provider_option_names::kProfilesMinShapes, info.profile_min_shapes)
@@ -56,6 +66,7 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi
 
   info.user_compute_stream = user_compute_stream;
   info.has_user_compute_stream = (user_compute_stream != nullptr);
+  info.user_aux_stream_array = user_aux_stream_array;
   info.onnx_bytestream = onnx_bytestream;
   info.external_data_bytestream = external_data_bytestream;
 
@@ -98,8 +109,10 @@ ProviderOptions NvExecutionProviderInfo::ToProviderOptions(const NvExecutionProv
       {nv::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
       {nv::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
       {nv::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
+      {nv::provider_option_names::kUserAuxStreamArray, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_aux_stream_array))},
      {nv::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.max_workspace_size)},
      {nv::provider_option_names::kMaxSharedMemSize, MakeStringWithClassicLocale(info.max_shared_mem_size)},
+      {nv::provider_option_names::kLengthAuxStreamArray, MakeStringWithClassicLocale(info.auxiliary_streams)},
      {nv::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.dump_subgraphs)},
      {nv::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.detailed_build_log)},
      {nv::provider_option_names::kProfilesMinShapes, MakeStringWithClassicLocale(info.profile_min_shapes)},
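For illustration only, the pointer-as-string round trip that the new `kUserAuxStreamArray` value parser relies on; `std::to_string` / `std::stoull` stand in here for the EP's `MakeStringWithClassicLocale` / `ParseStringWithClassicLocale` helpers.

```cpp
// Sketch: serializing a stream-array address into a provider-option string
// and recovering the original pointer on the consumer side.
#include <cuda_runtime_api.h>
#include <string>

int main() {
  cudaStream_t streams[2];
  cudaStreamCreate(&streams[0]);
  cudaStreamCreate(&streams[1]);

  // Producer side: the array address becomes a decimal string option value.
  std::string value = std::to_string(reinterpret_cast<size_t>(streams));

  // Consumer side (what FromProviderOptions effectively does): parse the
  // string back into an address and reinterpret it as the original pointer.
  size_t address = std::stoull(value);
  cudaStream_t* user_aux_stream_array = reinterpret_cast<cudaStream_t*>(address);
  (void)user_aux_stream_array;  // round trip preserves identity: == streams

  cudaStreamDestroy(streams[0]);
  cudaStreamDestroy(streams[1]);
  return 0;
}
```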

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ struct NvExecutionProviderInfo {
   int device_id{0};
   bool has_user_compute_stream{false};
   void* user_compute_stream{nullptr};
+  void* user_aux_stream_array{nullptr};
   int max_partition_iterations{1000};
   int min_subgraph_size{1};
   size_t max_workspace_size{0};
