Skip to content
Open
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,11 @@ WebGpuContext& WebGpuContextFactory::GetContext(int context_id) {
return *it->second.context;
}

// Returns whether a context with the given id currently exists in the factory.
bool WebGpuContextFactory::HasContext(int context_id) {
  std::lock_guard<std::mutex> lock(mutex_);
  return contexts_.count(context_id) != 0;
}

void WebGpuContextFactory::ReleaseContext(int context_id) {
std::lock_guard<std::mutex> lock(mutex_);

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class WebGpuContextFactory {

static WebGpuContext& CreateContext(const WebGpuContextConfig& config);
static WebGpuContext& GetContext(int context_id);
static bool HasContext(int context_id);

static void ReleaseContext(int context_id);

Expand Down
168 changes: 168 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,48 @@
#include "core/session/ort_apis.h"

#include "core/providers/webgpu/webgpu_provider_options.h"
#include "core/providers/webgpu/data_transfer.h"
using namespace onnxruntime::webgpu::options;

namespace onnxruntime {
// Bundle of everything needed to create and initialize a WebGpuContext implicitly
// (used by the data-transfer path when no context exists yet). Defaults are
// produced by GetDefaultWebGpuContextParams() below.
struct WebGpuContextParams {
  webgpu::WebGpuContextConfig context_config;           // passed to WebGpuContextFactory::CreateContext
  webgpu::WebGpuBufferCacheConfig buffer_cache_config;  // passed to WebGpuContext::Initialize
  int backend_type;                                     // a WGPUBackendType value, stored as int
  bool enable_pix_capture;                              // passed to WebGpuContext::Initialize
};

// Builds the default parameters used when a WebGPU context must be created
// implicitly (i.e. outside of an InferenceSession): context id 0, no externally
// supplied instance/device, validation disabled, and a platform-appropriate backend.
static WebGpuContextParams GetDefaultWebGpuContextParams() {
  WebGpuContextParams params;
  params.context_config.context_id = 0;
  params.context_config.instance = nullptr;         // no externally provided WebGPU instance
  params.context_config.device = nullptr;           // no externally provided device
  params.context_config.dawn_proc_table = nullptr;  // use the default proc table
  params.context_config.validation_mode = webgpu::ValidationMode::Disabled;
  params.context_config.preserve_device = false;
  params.context_config.max_storage_buffer_binding_size = 0;  // 0: no explicit override — presumably uses the default limit; verify
  params.context_config.power_preference = static_cast<int>(WGPUPowerPreference_HighPerformance);

  params.buffer_cache_config.storage.mode = webgpu::BufferCacheMode::Bucket;
  params.buffer_cache_config.uniform.mode = webgpu::BufferCacheMode::Simple;
  params.buffer_cache_config.query_resolve.mode = webgpu::BufferCacheMode::Disabled;
  params.buffer_cache_config.default_entry.mode = webgpu::BufferCacheMode::Disabled;

#ifdef _WIN32
#if defined(DAWN_ENABLE_D3D12)
  params.backend_type = static_cast<int>(WGPUBackendType_D3D12);
#elif defined(DAWN_ENABLE_VULKAN)
  params.backend_type = static_cast<int>(WGPUBackendType_Vulkan);
#else
  params.backend_type = static_cast<int>(WGPUBackendType_D3D12);
#endif
#else
  // WGPUBackendType_Undefined (0x00000000 in webgpu.h): let the implementation
  // choose the backend on non-Windows platforms. Same value as the literal 0
  // used previously, spelled out for readability.
  params.backend_type = static_cast<int>(WGPUBackendType_Undefined);
#endif
  params.enable_pix_capture = false;
  return params;
}

struct WebGpuProviderFactory : IExecutionProviderFactory {
WebGpuProviderFactory(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& webgpu_ep_config)
Expand Down Expand Up @@ -291,4 +330,133 @@
return std::make_shared<WebGpuProviderFactory>(context_id, context, std::move(webgpu_ep_config));
}

// WebGPU DataTransfer implementation wrapper for the C API with lazy initialization
struct WebGpuDataTransferImpl : OrtDataTransferImpl {
WebGpuDataTransferImpl(const OrtApi& ort_api_in)

Check warning on line 335 in onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Single-parameter constructors should be marked explicit. [runtime/explicit] [4] Raw Output: onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc:335: Single-parameter constructors should be marked explicit. [runtime/explicit] [4]
: ort_api{ort_api_in},
ep_api{*ort_api_in.GetEpApi()},
data_transfer_{nullptr},
context_id_{-1} {
ort_version_supported = ORT_API_VERSION;
CanCopy = CanCopyImpl;
CopyTensors = CopyTensorsImpl;
Release = ReleaseImpl;
}

static bool CanCopyImpl(const OrtDataTransferImpl* this_ptr,
const OrtMemoryDevice* src_memory_device,
const OrtMemoryDevice* dst_memory_device) noexcept {
const auto& impl = *static_cast<const WebGpuDataTransferImpl*>(this_ptr);
OrtMemoryInfoDeviceType src_type = impl.ep_api.MemoryDevice_GetDeviceType(src_memory_device);
OrtMemoryInfoDeviceType dst_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_memory_device);

// Check if at least one device is GPU
bool has_gpu = (src_type == OrtMemoryInfoDeviceType_GPU) || (dst_type == OrtMemoryInfoDeviceType_GPU);
if (!has_gpu) {
return false;
}

// WebGPU uses vendor ID 0 (VendorIds::NONE). Only handle GPU devices with vendor ID 0.
// This prevents attempting to copy data for other EPs' fake GPU devices (e.g., example EP with vendor 0xBE57)
if (src_type == OrtMemoryInfoDeviceType_GPU) {
uint32_t src_vendor = impl.ep_api.MemoryDevice_GetVendorId(src_memory_device);
if (src_vendor != 0) {
return false; // Not a WebGPU device
}
}

if (dst_type == OrtMemoryInfoDeviceType_GPU) {
uint32_t dst_vendor = impl.ep_api.MemoryDevice_GetVendorId(dst_memory_device);
if (dst_vendor != 0) {
return false; // Not a WebGPU device
}
}

// WebGPU supports GPU<->GPU, GPU<->CPU copies (where GPU has vendor ID 0)
return (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) ||
(src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_CPU) ||
(src_type == OrtMemoryInfoDeviceType_CPU && dst_type == OrtMemoryInfoDeviceType_GPU);
}

static OrtStatus* CopyTensorsImpl(OrtDataTransferImpl* this_ptr,
const OrtValue** src_tensors,
OrtValue** dst_tensors,
OrtSyncStream** /*streams*/,
size_t num_tensors) noexcept {
auto& impl = *static_cast<WebGpuDataTransferImpl*>(this_ptr);

if (num_tensors == 0) {
return nullptr;
}

// Lazy initialization: Get context_id from the first GPU tensor's device
int context_id = 0; // Default to context_id 0
bool found_gpu_tensor = false;

// Check both src_tensors and dst_tensors to find the first GPU tensor
for (size_t idx = 0; idx < num_tensors && !found_gpu_tensor; ++idx) {
// Check source tensor
const OrtMemoryDevice* src_device = impl.ep_api.Value_GetMemoryDevice(src_tensors[idx]);
OrtMemoryInfoDeviceType src_device_type = impl.ep_api.MemoryDevice_GetDeviceType(src_device);
if (src_device_type == OrtMemoryInfoDeviceType_GPU) {
context_id = static_cast<int>(impl.ep_api.MemoryDevice_GetDeviceId(src_device));
found_gpu_tensor = true;
break;
}

// Check destination tensor
const OrtMemoryDevice* dst_device = impl.ep_api.Value_GetMemoryDevice(dst_tensors[idx]);
OrtMemoryInfoDeviceType dst_device_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_device);
if (dst_device_type == OrtMemoryInfoDeviceType_GPU) {
context_id = static_cast<int>(impl.ep_api.MemoryDevice_GetDeviceId(dst_device));
found_gpu_tensor = true;
break;
}
}

// Initialize data_transfer if not already done or if context_id changed
if (impl.data_transfer_ == nullptr || impl.context_id_ != context_id) {
impl.context_id_ = context_id;

// Check if context exists, create a default one if it doesn't
webgpu::WebGpuContext* context_ptr = nullptr;
if (webgpu::WebGpuContextFactory::HasContext(context_id)) {
context_ptr = &webgpu::WebGpuContextFactory::GetContext(context_id);
} else {
WebGpuContextParams params = GetDefaultWebGpuContextParams();
params.context_config.context_id = context_id;
context_ptr = &webgpu::WebGpuContextFactory::CreateContext(params.context_config);
context_ptr->Initialize(params.buffer_cache_config, params.backend_type, params.enable_pix_capture);
}

// Create the DataTransfer instance
impl.data_transfer_ = std::make_unique<webgpu::DataTransfer>(context_ptr->BufferManager());
}

// Now perform the actual tensor copy
for (size_t idx = 0; idx < num_tensors; ++idx) {
const OrtValue* src_tensor = src_tensors[idx];
OrtValue* dst_tensor = dst_tensors[idx];
auto status = impl.data_transfer_->CopyTensor(src_tensor->Get<Tensor>(), *dst_tensor->GetMutable<Tensor>());
if (!status.IsOK()) {
return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, status.ErrorMessage().c_str());
}
}
return nullptr;
}

static void ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept {
delete static_cast<WebGpuDataTransferImpl*>(this_ptr);
}

const OrtApi& ort_api;
const OrtEpApi& ep_api;
std::unique_ptr<webgpu::DataTransfer> data_transfer_; // Lazy-initialized

Check warning on line 454 in onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc:454: Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]
int context_id_; // Track which context we're using
};

// Factory for the C-API data-transfer wrapper. The caller takes ownership of the
// returned pointer and releases it through its Release callback.
OrtDataTransferImpl* OrtWebGpuCreateDataTransfer() {
  const OrtApi& api = *OrtApis::GetApi(ORT_API_VERSION);
  return new WebGpuDataTransferImpl(api);
}

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,18 @@

#include "core/providers/webgpu/webgpu_provider_options.h"

struct OrtDataTransferImpl;

namespace onnxruntime {
struct ConfigOptions;

struct WebGpuProviderFactoryCreator {
  // Creates the WebGPU execution provider factory from session config options.
  static std::shared_ptr<IExecutionProviderFactory> Create(const ConfigOptions& config_options);
};

// C API to create data transfer for WebGPU EP with lazy initialization.
// The context is determined from the tensors during the first CopyTensors call.
// Caller takes ownership of the returned OrtDataTransferImpl* (freed via its Release callback).
OrtDataTransferImpl* OrtWebGpuCreateDataTransfer();

} // namespace onnxruntime
17 changes: 5 additions & 12 deletions onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,13 @@ OrtStatus* WebGpuEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* co
return nullptr;
}

/* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of
an InferenceSession.
OrtStatus* WebGpuEpFactory::CreateAllocator(const OrtMemoryInfo* memory_info,
const OrtKeyValuePairs* allocator_options,
OrtAllocator** allocator) noexcept override {
*allocator = device_allocators[memory_info->device.Id()].get();
}

OrtStatus* WebGpuEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override {
// TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors.
*data_transfer = nullptr;
OrtStatus* WebGpuEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept {
  // The concrete implementation lives in the WebGPU provider backend, which is
  // the only place with access to the WebGPU headers.
  OrtDataTransferImpl* impl = OrtWebGpuCreateDataTransfer();
  *data_transfer = impl;
  return nullptr;  // success
}
*/

} // namespace onnxruntime

#endif // USE_WEBGPU
2 changes: 2 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class WebGpuEpFactory : public EpFactoryInternalImpl {
const OrtSessionOptions* session_options,
const OrtLogger* session_logger,
std::unique_ptr<IExecutionProvider>* ep) noexcept override;

OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept override;
};
} // namespace onnxruntime

Expand Down
Loading