Skip to content

Commit 32d101d

Browse files
qjia7 and Copilot
authored
Enable graph capture for webgpu (#1848)
This PR enables graph capture for webgpu. It implements the CopyDeviceToCpu/CopyCpuToDevice/CopyFrom/Zero functions using the new `CopyTensors` API. The ort part needs to apply this PR [#26450](microsoft/onnxruntime#26450) to make it work for webgpu. The following things will be implemented in follow-up PRs to get the full performance gain for graph capture (the original one is #1720): 1. Support UpdateAttentionMask, UpdatePositionIds, and Cast to keep the whole pipeline on gpu. 2. Optimize CopyFrom with offsets. --------- Co-authored-by: Copilot <[email protected]>
1 parent 7bf81d5 commit 32d101d

File tree

13 files changed

+189
-22
lines changed

13 files changed

+189
-22
lines changed

.pipelines/nuget-publishing.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ parameters:
6161
- name: ort_version
6262
displayName: 'OnnxRuntime version'
6363
type: string
64-
default: '1.22.0'
64+
default: '1.23.0'
6565

6666
- name: ort_winml_version
6767
displayName: 'Microsoft.WindowsAppSDK.ML Version (should match CMakeList.txt)'
@@ -76,12 +76,12 @@ parameters:
7676
- name: ort_cuda_version
7777
displayName: 'OnnxRuntime GPU version'
7878
type: string
79-
default: '1.22.0'
79+
default: '1.23.0'
8080

8181
- name: ort_dml_version
8282
displayName: 'OnnxRuntime DML version'
8383
type: string
84-
default: '1.22.0'
84+
default: '1.23.0'
8585

8686
- name: cuda_version
8787
displayName: 'CUDA version'

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ if(MSVC)
8080
"$<$<COMPILE_LANGUAGE:C>:/wd4100>"
8181
"$<$<COMPILE_LANGUAGE:CXX>:/wd4100>"
8282

83+
# Suppress warning C4819: file contains character that cannot be represented in current code page
84+
"$<$<COMPILE_LANGUAGE:C>:/wd4819>"
85+
"$<$<COMPILE_LANGUAGE:CXX>:/wd4819>"
86+
8387
# Enable warning level 4 (more aggressive than default /W3)
8488
# Captures more potential bugs or code smells
8589
"$<$<COMPILE_LANGUAGE:C>:/W4>"

cmake/ortlib.cmake

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,16 @@ if(ORT_HOME)
8181
endif()
8282
else()
8383
# If ORT_HOME is not specified, download the onnxruntime headers and libraries from the nightly feed
84-
set(ORT_VERSION "1.22.0")
84+
set(ORT_VERSION "1.23.0")
8585
set(ORT_FEED_ORG_NAME "aiinfra")
8686
set(ORT_FEED_PROJECT "2692857e-05ef-43b4-ba9c-ccf1c22c437c")
8787
set(ORT_NIGHTLY_FEED_ID "7982ae20-ed19-4a35-a362-a96ac99897b7")
8888

8989
if (USE_DML)
90-
set(ORT_VERSION "1.22.0")
90+
set(ORT_VERSION "1.23.0")
9191
set(ORT_PACKAGE_NAME "Microsoft.ML.OnnxRuntime.DirectML")
9292
elseif(USE_CUDA)
93-
set(ORT_VERSION "1.22.0")
93+
set(ORT_VERSION "1.23.0")
9494
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
9595
set(ORT_PACKAGE_NAME "Microsoft.ML.OnnxRuntime.Gpu.Linux")
9696
elseif(WIN32)
@@ -99,7 +99,7 @@ else()
9999
message(FATAL_ERROR "Unsupported platform for CUDA")
100100
endif()
101101
elseif(USE_ROCM)
102-
set(ORT_VERSION "1.22.0")
102+
set(ORT_VERSION "1.23.0")
103103
set(ORT_PACKAGE_NAME "Microsoft.ML.OnnxRuntime.Rocm")
104104
else()
105105
set(ORT_PACKAGE_NAME "Microsoft.ML.OnnxRuntime")

examples/slm_engine/build_scripts/build_deps.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -577,9 +577,9 @@ def main():
577577
ort_home = None
578578
if args.build_ort_from_source:
579579
if args.ort_version_to_use is None:
580-
# If not Windows then use 1.22.0
580+
# If not Windows then use 1.23.0
581581
if platform.system() != "Windows":
582-
args.ort_version_to_use = "v1.22.0"
582+
args.ort_version_to_use = "v1.23.0"
583583
else:
584584
args.ort_version_to_use = "main"
585585
ort_home = build_ort(args, dep_src_dir, artifacts_dir)
@@ -590,7 +590,7 @@ def main():
590590
# The ORT binaries are available as they were downloaded during the GenAI build
591591
# This is the supported version for most platforms
592592
if args.ort_version_to_use is None:
593-
ORT_VERSION = "1.22.0"
593+
ORT_VERSION = "1.23.0"
594594
else:
595595
ORT_VERSION = args.ort_version_to_use
596596
# Copy the ORT artifacts to the artifacts directory.

src/config.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,13 @@ bool IsGraphCaptureEnabled(const Config::SessionOptions& session_options) {
11051105
}
11061106
} else if (provider_options->name == "DML") {
11071107
return true;
1108+
} else if (provider_options->name == "WebGPU") {
1109+
for (const auto& value : provider_options->options) {
1110+
if (value.first == "enableGraphCapture" && value.second == "1") {
1111+
return true;
1112+
}
1113+
}
1114+
return false;
11081115
} else if (provider_options->name == "NvTensorRtRtx") {
11091116
for (const auto& value : provider_options->options) {
11101117
if (value.first == "enable_cuda_graph" && value.second == "1") {

src/models/model.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,9 @@ Model::Model(std::unique_ptr<Config> config) : config_{std::move(config)} {
953953
EnsureDeviceOrtInit(*p_device_, *config_, arena_cfg_);
954954

955955
// Only CUDA, TRT-RTX and DML does every input on the device
956-
if (p_device_->GetType() == DeviceType::CUDA || p_device_->GetType() == DeviceType::DML || p_device_->GetType() == DeviceType::NvTensorRtRtx)
956+
// For WebGPU, use device memory only if graph capture is enabled, otherwise use CPU
957+
if (p_device_->GetType() == DeviceType::CUDA || p_device_->GetType() == DeviceType::DML || p_device_->GetType() == DeviceType::NvTensorRtRtx ||
958+
(p_device_->GetType() == DeviceType::WEBGPU && IsGraphCaptureEnabled(config_->model.decoder.session_options)))
957959
p_device_inputs_ = p_device_;
958960
else
959961
p_device_inputs_ = GetDeviceInterface(DeviceType::CPU);

src/models/onnxruntime_api.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,14 @@ struct OrtEnv {
477477

478478
OrtEnv& CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info, const OrtArenaCfg& arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator
479479

480+
/// \brief Copy tensors between devices. Wraps OrtApi::CopyTensors
481+
/// \param src_tensors Array of source OrtValue tensors
482+
/// \param dst_tensors Array of destination OrtValue tensors (must be pre-allocated)
483+
/// \param stream Optional sync stream for asynchronous copy (can be nullptr for synchronous)
484+
void CopyTensors(const std::vector<const OrtValue*>& src_tensors,
485+
const std::vector<OrtValue*>& dst_tensors,
486+
OrtSyncStream* stream = nullptr) const;
487+
480488
std::vector<const OrtEpDevice*> GetEpDevices();
481489

482490
static void operator delete(void* p) { Ort::api->ReleaseEnv(reinterpret_cast<OrtEnv*>(p)); }
@@ -848,6 +856,26 @@ struct OrtShape {
848856
size_t shape_len;
849857
};
850858

859+
/** \brief Wrapper around ::OrtSyncStream
860+
*
861+
* Used for asynchronous operations like CopyTensors.
862+
* Requires ONNX Runtime 1.23.0 or later.
863+
*/
864+
struct OrtSyncStream {
865+
/// \brief Create a sync stream for a specific execution provider device
866+
/// \param ep_device The execution provider device (from OrtEnv::GetEpDevices)
867+
/// \param stream_options Optional stream configuration options
868+
static std::unique_ptr<OrtSyncStream> Create(const OrtEpDevice* ep_device, const OrtKeyValuePairs* stream_options = nullptr);
869+
870+
/// \brief Get the native stream handle (e.g., cudaStream_t for CUDA)
871+
void* GetHandle() const;
872+
873+
static void operator delete(void* p) {
874+
if (p) Ort::api->ReleaseSyncStream(reinterpret_cast<OrtSyncStream*>(p));
875+
}
876+
Ort::Abstract make_abstract;
877+
};
878+
851879
/** \brief Wrapper around ::OrtValue
852880
*
853881
*/

src/models/onnxruntime_inline.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,16 @@ inline std::unique_ptr<OrtMemoryInfo> OrtMemoryInfo::Create(const char* name, Or
256256
return std::unique_ptr<OrtMemoryInfo>{p};
257257
}
258258

259+
inline std::unique_ptr<OrtSyncStream> OrtSyncStream::Create(const OrtEpDevice* ep_device, const OrtKeyValuePairs* stream_options) {
260+
OrtSyncStream* p_stream = nullptr;
261+
Ort::ThrowOnError(Ort::api->CreateSyncStreamForEpDevice(ep_device, stream_options, &p_stream));
262+
return std::unique_ptr<OrtSyncStream>(p_stream);
263+
}
264+
265+
inline void* OrtSyncStream::GetHandle() const {
266+
return Ort::api->SyncStream_GetHandle(const_cast<OrtSyncStream*>(this));
267+
}
268+
259269
inline std::unique_ptr<OrtIoBinding> OrtIoBinding::Create(OrtSession& session) {
260270
OrtIoBinding* p;
261271
Ort::ThrowOnError(Ort::api->CreateIoBinding(&session, &p));
@@ -398,6 +408,15 @@ inline OrtEnv& OrtEnv::CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info,
398408
return *this;
399409
}
400410

411+
inline void OrtEnv::CopyTensors(const std::vector<const OrtValue*>& src_tensors,
412+
const std::vector<OrtValue*>& dst_tensors,
413+
OrtSyncStream* stream) const {
414+
if (src_tensors.size() != dst_tensors.size()) {
415+
throw std::runtime_error("Number of source and destination tensors must match");
416+
}
417+
Ort::ThrowOnError(Ort::api->CopyTensors(this, src_tensors.data(), dst_tensors.data(), stream, src_tensors.size()));
418+
}
419+
401420
inline std::vector<const OrtEpDevice*> OrtEnv::GetEpDevices() {
402421
size_t num_devices = 0;
403422
const OrtEpDevice* const* device_ptrs = nullptr;

src/webgpu/interface.cpp

Lines changed: 114 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,136 @@ const char* device_label = "WebGPU";
1414
struct WebGPUMemory final : DeviceBuffer {
1515
WebGPUMemory(size_t size) : owned_{true} {
1616
size_in_bytes_ = size;
17-
p_cpu_ = p_device_ = static_cast<uint8_t*>(ort_allocator_->Alloc(size_in_bytes_));
17+
p_device_ = static_cast<uint8_t*>(ort_allocator_->Alloc(size_in_bytes_));
1818
}
1919

2020
WebGPUMemory(void* p, size_t size) : owned_{false} {
2121
size_in_bytes_ = size;
22-
p_cpu_ = p_device_ = static_cast<uint8_t*>(p);
22+
p_device_ = static_cast<uint8_t*>(p);
2323
}
2424

2525
~WebGPUMemory() override {
2626
if (owned_)
2727
ort_allocator_->Free(p_device_);
28+
if (p_cpu_)
29+
free(p_cpu_);
2830
}
2931

3032
const char* GetType() const override { return device_label; }
31-
void AllocateCpu() override { throw std::runtime_error("CPU can't access WebGPU memory"); }
32-
void CopyDeviceToCpu() override { throw std::runtime_error("CPU can't access WebGPU memory"); }
33-
void CopyCpuToDevice() override { throw std::runtime_error("CPU can't access WebGPU memory"); }
33+
34+
void AllocateCpu() override {
35+
if (!p_cpu_)
36+
p_cpu_ = static_cast<uint8_t*>(malloc(size_in_bytes_));
37+
}
38+
39+
void CopyDeviceToCpu() override {
40+
if (!ort_allocator_) {
41+
throw std::runtime_error("WebGPU allocator not initialized");
42+
}
43+
44+
AllocateCpu();
45+
46+
// Get WebGPU allocator's memory info
47+
const OrtMemoryInfo* webgpu_mem_info = nullptr;
48+
Ort::ThrowOnError(Ort::api->AllocatorGetInfo(ort_allocator_, &webgpu_mem_info));
49+
50+
// Create source tensor (WebGPU device memory) - treat as 1D uint8 array
51+
int64_t shape_val = static_cast<int64_t>(size_in_bytes_);
52+
std::span<const int64_t> shape{&shape_val, 1};
53+
auto src_tensor = OrtValue::CreateTensor(*webgpu_mem_info, p_device_, size_in_bytes_, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
54+
55+
// Create CPU memory info and destination tensor
56+
auto cpu_mem_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
57+
auto dst_tensor = OrtValue::CreateTensor(*cpu_mem_info, p_cpu_, size_in_bytes_, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
58+
59+
// Use ORT C API's CopyTensors (synchronous copy, stream = nullptr)
60+
OrtValue* src_ptrs[] = {src_tensor.get()};
61+
OrtValue* dst_ptrs[] = {dst_tensor.get()};
62+
Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
63+
}
64+
65+
void CopyCpuToDevice() override {
66+
if (!ort_allocator_) {
67+
throw std::runtime_error("WebGPU allocator not initialized");
68+
}
69+
assert(p_cpu_);
70+
71+
// Get WebGPU allocator's memory info
72+
const OrtMemoryInfo* webgpu_mem_info = nullptr;
73+
Ort::ThrowOnError(Ort::api->AllocatorGetInfo(ort_allocator_, &webgpu_mem_info));
74+
75+
// Create source tensor (CPU memory) - treat as 1D uint8 array
76+
int64_t shape_val = static_cast<int64_t>(size_in_bytes_);
77+
std::span<const int64_t> shape{&shape_val, 1};
78+
auto cpu_mem_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
79+
auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, p_cpu_, size_in_bytes_, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
80+
81+
// Create destination tensor (WebGPU device memory)
82+
auto dst_tensor = OrtValue::CreateTensor(*webgpu_mem_info, p_device_, size_in_bytes_, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
83+
84+
// Use ORT C API's CopyTensors (synchronous copy, stream = nullptr)
85+
OrtValue* src_ptrs[] = {src_tensor.get()};
86+
OrtValue* dst_ptrs[] = {dst_tensor.get()};
87+
Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
88+
}
89+
3490
void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override {
35-
throw std::runtime_error("CPU can't access WebGPU memory");
91+
if (!ort_allocator_) {
92+
throw std::runtime_error("WebGPU allocator not initialized");
93+
}
94+
95+
// Fast path: WebGPU-to-WebGPU copy with zero offsets
96+
// NOTE: p_device_ is a WGPUBuffer handle (cast to uint8_t*), not a memory pointer.
97+
// We cannot use pointer arithmetic (p_device_ + offset) to create sub-buffer views.
98+
// OrtValue::CreateTensor expects the actual buffer handle, not an offset pointer.
99+
if (source.GetType() == device_label && begin_source == 0 && begin_dest == 0) {
100+
// Get WebGPU allocator's memory info
101+
const OrtMemoryInfo* webgpu_mem_info = nullptr;
102+
Ort::ThrowOnError(Ort::api->AllocatorGetInfo(ort_allocator_, &webgpu_mem_info));
103+
104+
// Full buffer copy using CopyTensors (no offsets)
105+
int64_t shape_val = static_cast<int64_t>(size_in_bytes);
106+
std::span<const int64_t> shape{&shape_val, 1};
107+
auto src_tensor = OrtValue::CreateTensor(*webgpu_mem_info, source.p_device_, size_in_bytes, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
108+
auto dst_tensor = OrtValue::CreateTensor(*webgpu_mem_info, p_device_, size_in_bytes, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
109+
110+
// Use ORT C API's CopyTensors for GPU-to-GPU copy
111+
OrtValue* src_ptrs[] = {src_tensor.get()};
112+
OrtValue* dst_ptrs[] = {dst_tensor.get()};
113+
Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
114+
} else {
115+
// Fallback: Copy through CPU for:
116+
// - WebGPU-to-WebGPU copies with non-zero offsets (buffer handles don't support offset arithmetic)
117+
// - Cross-device copies (e.g., CPU to WebGPU or vice versa)
118+
CopyThroughCpu(*this, begin_dest, source, begin_source, size_in_bytes);
119+
}
36120
}
37121

38122
void Zero() override {
39-
throw std::runtime_error("Zeroing not implemented for WebGPU memory");
123+
if (!ort_allocator_) {
124+
throw std::runtime_error("WebGPU allocator not initialized");
125+
}
126+
127+
// Allocate zeroed CPU memory
128+
std::vector<uint8_t> zero_buffer(size_in_bytes_, 0);
129+
130+
// Get WebGPU allocator's memory info
131+
const OrtMemoryInfo* webgpu_mem_info = nullptr;
132+
Ort::ThrowOnError(Ort::api->AllocatorGetInfo(ort_allocator_, &webgpu_mem_info));
133+
134+
// Create source tensor (CPU memory with zeros) - treat as 1D uint8 array
135+
int64_t shape_val = static_cast<int64_t>(size_in_bytes_);
136+
std::span<const int64_t> shape{&shape_val, 1};
137+
auto cpu_mem_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
138+
auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, zero_buffer.data(), size_in_bytes_, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
139+
140+
// Create destination tensor (WebGPU device memory)
141+
auto dst_tensor = OrtValue::CreateTensor(*webgpu_mem_info, p_device_, size_in_bytes_, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
142+
143+
// Use ORT C API's CopyTensors to copy zeros to GPU (synchronous copy, stream = nullptr)
144+
OrtValue* src_ptrs[] = {src_tensor.get()};
145+
OrtValue* dst_ptrs[] = {dst_tensor.get()};
146+
Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
40147
}
41148

42149
bool owned_;
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
onnxruntime==1.22.0
1+
onnxruntime==1.23.0

0 commit comments

Comments
 (0)