@@ -14,29 +14,136 @@ const char* device_label = "WebGPU";
1414struct WebGPUMemory final : DeviceBuffer {
1515 WebGPUMemory (size_t size) : owned_{true } {
1616 size_in_bytes_ = size;
17- p_cpu_ = p_device_ = static_cast <uint8_t *>(ort_allocator_->Alloc (size_in_bytes_));
17+ p_device_ = static_cast <uint8_t *>(ort_allocator_->Alloc (size_in_bytes_));
1818 }
1919
2020 WebGPUMemory (void * p, size_t size) : owned_{false } {
2121 size_in_bytes_ = size;
22- p_cpu_ = p_device_ = static_cast <uint8_t *>(p);
22+ p_device_ = static_cast <uint8_t *>(p);
2323 }
2424
2525 ~WebGPUMemory () override {
2626 if (owned_)
2727 ort_allocator_->Free (p_device_);
28+ if (p_cpu_)
29+ free (p_cpu_);
2830 }
2931
3032 const char * GetType () const override { return device_label; }
31- void AllocateCpu () override { throw std::runtime_error (" CPU can't access WebGPU memory" ); }
32- void CopyDeviceToCpu () override { throw std::runtime_error (" CPU can't access WebGPU memory" ); }
33- void CopyCpuToDevice () override { throw std::runtime_error (" CPU can't access WebGPU memory" ); }
33+
34+ void AllocateCpu () override {
35+ if (!p_cpu_)
36+ p_cpu_ = static_cast <uint8_t *>(malloc (size_in_bytes_));
37+ }
38+
39+ void CopyDeviceToCpu () override {
40+ if (!ort_allocator_) {
41+ throw std::runtime_error (" WebGPU allocator not initialized" );
42+ }
43+
44+ AllocateCpu ();
45+
46+ // Get WebGPU allocator's memory info
47+ const OrtMemoryInfo* webgpu_mem_info = nullptr ;
48+ Ort::ThrowOnError (Ort::api->AllocatorGetInfo (ort_allocator_, &webgpu_mem_info));
49+
50+ // Create source tensor (WebGPU device memory) - treat as 1D uint8 array
51+ int64_t shape_val = static_cast <int64_t >(size_in_bytes_);
52+ std::span<const int64_t > shape{&shape_val, 1 };
53+ auto src_tensor = OrtValue::CreateTensor (*webgpu_mem_info, p_device_, size_in_bytes_, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
54+
55+ // Create CPU memory info and destination tensor
56+ auto cpu_mem_info = OrtMemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeDefault);
57+ auto dst_tensor = OrtValue::CreateTensor (*cpu_mem_info, p_cpu_, size_in_bytes_, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
58+
59+ // Use ORT C API's CopyTensors (synchronous copy, stream = nullptr)
60+ OrtValue* src_ptrs[] = {src_tensor.get ()};
61+ OrtValue* dst_ptrs[] = {dst_tensor.get ()};
62+ Ort::ThrowOnError (Ort::api->CopyTensors (&GetOrtEnv (), src_ptrs, dst_ptrs, nullptr , 1 ));
63+ }
64+
65+ void CopyCpuToDevice () override {
66+ if (!ort_allocator_) {
67+ throw std::runtime_error (" WebGPU allocator not initialized" );
68+ }
69+ assert (p_cpu_);
70+
71+ // Get WebGPU allocator's memory info
72+ const OrtMemoryInfo* webgpu_mem_info = nullptr ;
73+ Ort::ThrowOnError (Ort::api->AllocatorGetInfo (ort_allocator_, &webgpu_mem_info));
74+
75+ // Create source tensor (CPU memory) - treat as 1D uint8 array
76+ int64_t shape_val = static_cast <int64_t >(size_in_bytes_);
77+ std::span<const int64_t > shape{&shape_val, 1 };
78+ auto cpu_mem_info = OrtMemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeDefault);
79+ auto src_tensor = OrtValue::CreateTensor (*cpu_mem_info, p_cpu_, size_in_bytes_, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
80+
81+ // Create destination tensor (WebGPU device memory)
82+ auto dst_tensor = OrtValue::CreateTensor (*webgpu_mem_info, p_device_, size_in_bytes_, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
83+
84+ // Use ORT C API's CopyTensors (synchronous copy, stream = nullptr)
85+ OrtValue* src_ptrs[] = {src_tensor.get ()};
86+ OrtValue* dst_ptrs[] = {dst_tensor.get ()};
87+ Ort::ThrowOnError (Ort::api->CopyTensors (&GetOrtEnv (), src_ptrs, dst_ptrs, nullptr , 1 ));
88+ }
89+
3490 void CopyFrom (size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override {
35- throw std::runtime_error (" CPU can't access WebGPU memory" );
91+ if (!ort_allocator_) {
92+ throw std::runtime_error (" WebGPU allocator not initialized" );
93+ }
94+
95+ // Fast path: WebGPU-to-WebGPU copy with zero offsets
96+ // NOTE: p_device_ is a WGPUBuffer handle (cast to uint8_t*), not a memory pointer.
97+ // We cannot use pointer arithmetic (p_device_ + offset) to create sub-buffer views.
98+ // OrtValue::CreateTensor expects the actual buffer handle, not an offset pointer.
99+ if (source.GetType () == device_label && begin_source == 0 && begin_dest == 0 ) {
100+ // Get WebGPU allocator's memory info
101+ const OrtMemoryInfo* webgpu_mem_info = nullptr ;
102+ Ort::ThrowOnError (Ort::api->AllocatorGetInfo (ort_allocator_, &webgpu_mem_info));
103+
104+ // Full buffer copy using CopyTensors (no offsets)
105+ int64_t shape_val = static_cast <int64_t >(size_in_bytes);
106+ std::span<const int64_t > shape{&shape_val, 1 };
107+ auto src_tensor = OrtValue::CreateTensor (*webgpu_mem_info, source.p_device_ , size_in_bytes, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
108+ auto dst_tensor = OrtValue::CreateTensor (*webgpu_mem_info, p_device_, size_in_bytes, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
109+
110+ // Use ORT C API's CopyTensors for GPU-to-GPU copy
111+ OrtValue* src_ptrs[] = {src_tensor.get ()};
112+ OrtValue* dst_ptrs[] = {dst_tensor.get ()};
113+ Ort::ThrowOnError (Ort::api->CopyTensors (&GetOrtEnv (), src_ptrs, dst_ptrs, nullptr , 1 ));
114+ } else {
115+ // Fallback: Copy through CPU for:
116+ // - WebGPU-to-WebGPU copies with non-zero offsets (buffer handles don't support offset arithmetic)
117+ // - Cross-device copies (e.g., CPU to WebGPU or vice versa)
118+ CopyThroughCpu (*this , begin_dest, source, begin_source, size_in_bytes);
119+ }
36120 }
37121
38122 void Zero () override {
39- throw std::runtime_error (" Zeroing not implemented for WebGPU memory" );
123+ if (!ort_allocator_) {
124+ throw std::runtime_error (" WebGPU allocator not initialized" );
125+ }
126+
127+ // Allocate zeroed CPU memory
128+ std::vector<uint8_t > zero_buffer (size_in_bytes_, 0 );
129+
130+ // Get WebGPU allocator's memory info
131+ const OrtMemoryInfo* webgpu_mem_info = nullptr ;
132+ Ort::ThrowOnError (Ort::api->AllocatorGetInfo (ort_allocator_, &webgpu_mem_info));
133+
134+ // Create source tensor (CPU memory with zeros) - treat as 1D uint8 array
135+ int64_t shape_val = static_cast <int64_t >(size_in_bytes_);
136+ std::span<const int64_t > shape{&shape_val, 1 };
137+ auto cpu_mem_info = OrtMemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeDefault);
138+ auto src_tensor = OrtValue::CreateTensor (*cpu_mem_info, zero_buffer.data (), size_in_bytes_, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
139+
140+ // Create destination tensor (WebGPU device memory)
141+ auto dst_tensor = OrtValue::CreateTensor (*webgpu_mem_info, p_device_, size_in_bytes_, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
142+
143+ // Use ORT C API's CopyTensors to copy zeros to GPU (synchronous copy, stream = nullptr)
144+ OrtValue* src_ptrs[] = {src_tensor.get ()};
145+ OrtValue* dst_ptrs[] = {dst_tensor.get ()};
146+ Ort::ThrowOnError (Ort::api->CopyTensors (&GetOrtEnv (), src_ptrs, dst_ptrs, nullptr , 1 ));
40147 }
41148
42149 bool owned_;
0 commit comments