diff --git a/Compiling.md b/Compiling.md index 74932a1f4..95003ab6a 100644 --- a/Compiling.md +++ b/Compiling.md @@ -33,6 +33,8 @@ As also mentioned in the instructions below but repeated here for visibility, if * If using the OpenCL backend, a modern GPU that supports OpenCL 1.2 or greater, or else something like [this](https://software.intel.com/en-us/opencl-sdk) for CPU. But if using CPU, Eigen should be better. * If using the CUDA backend, CUDA 11 or later and a compatible version of CUDNN based on your CUDA version (https://developer.nvidia.com/cuda-toolkit) (https://developer.nvidia.com/cudnn) and a GPU capable of supporting them. * If using the TensorRT backend, in addition to a compatible CUDA Toolkit (https://developer.nvidia.com/cuda-toolkit), you also need TensorRT (https://developer.nvidia.com/tensorrt) that is at least version 8.5. + * If using the ROCm backend, ROCm 6.4 or later and a GPU capable of supporting them. More information about installation(https://rocm.docs.amd.com/projects/install-on-linux/en/latest/) and please install all possible ROCm developer packages, instead of just ROCm runtime packages. + * If using the MIGraphX backend, ROCm 7.0 or later with MIGraphX library installed (e.g. `sudo apt install migraphx` via the ROCm package repo). * If using the Eigen backend, Eigen3. With Debian packages, (i.e. apt or apt-get), this should be `libeigen3-dev`. * zlib, libzip. With Debian packages (i.e. apt or apt-get), these should be `zlib1g-dev`, `libzip-dev`. * If you want to do self-play training and research, probably Google perftools `libgoogle-perftools-dev` for TCMalloc or some other better malloc implementation. For unknown reasons, the allocation pattern in self-play with large numbers of threads and parallel games causes a lot of memory fragmentation under glibc malloc that will eventually run your machine out of memory, but better mallocs handle it fine. @@ -41,7 +43,7 @@ As also mentioned in the instructions below but repeated here for visibility, if * `git clone https://github.com/lightvector/KataGo.git` * Compile using CMake and make in the cpp directory: * `cd KataGo/cpp` - * `cmake . -DUSE_BACKEND=OPENCL` or `cmake . -DUSE_BACKEND=CUDA` or `cmake . -DUSE_BACKEND=TENSORRT` or `cmake . -DUSE_BACKEND=EIGEN` depending on which backend you want. + * `cmake . -DUSE_BACKEND=OPENCL` or `cmake . -DUSE_BACKEND=CUDA` or `cmake . -DUSE_BACKEND=TENSORRT` or `cmake . -DUSE_BACKEND=EIGEN` or `cmake . -DUSE_BACKEND=ROCM` or `cmake . -DUSE_BACKEND=MIGRAPHX` depending on which backend you want. * Specify also `-DUSE_TCMALLOC=1` if using TCMalloc. * Compiling will also call git commands to embed the git hash into the compiled executable, specify also `-DNO_GIT_REVISION=1` to disable it if this is causing issues for you. * Specify `-DUSE_AVX2=1` to also compile Eigen with AVX2 and FMA support, which will make it incompatible with old CPUs but much faster. (If you want to go further, you can also add `-DCMAKE_CXX_FLAGS='-march=native'` which will specialize to precisely your machine's CPU, but the exe might not run on other machines at all). @@ -54,6 +56,30 @@ As also mentioned in the instructions below but repeated here for visibility, if * You will probably want to edit `configs/gtp_example.cfg` (see "Tuning for Performance" above). * If using OpenCL, you will want to verify that KataGo is picking up the correct device when you run it (e.g. some systems may have both an Intel CPU OpenCL and GPU OpenCL, if KataGo appears to pick the wrong one, you can correct this by specifying `openclGpuToUse` in `configs/gtp_example.cfg`). + * **ROCm backend (Linux) — additional notes:** + * Install ROCm following the [official guide](https://rocm.docs.amd.com/en/7.12.0-preview/install/rocm.html). Install the full developer stack (not just runtime): `sudo apt install rocm-dev miopen-hip rocblas hipblas`. + * Build: + ``` + cd KataGo/cpp + mkdir build && cd build + cmake .. -DUSE_BACKEND=ROCM -DCMAKE_BUILD_TYPE=Release + make -j$(nproc) + ``` + * GPU architecture is auto-detected via `amdgpu-arch`. If auto-detection fails, specify manually: `-DCMAKE_HIP_ARCHITECTURES=gfx1100` (replace with your GPU's gfx target). + * On first run, MIOpen will search for optimal convolution algorithms for your specific GPU and network size. This may take up to a minute and results are cached in `~/.config/miopen/` for subsequent runs. + + * **MIGraphX backend (Linux) — additional notes:** + * Requires ROCm 7.0+ with MIGraphX installed. Install via: `sudo apt install migraphx`. + * Build: + ``` + cd KataGo/cpp + mkdir build && cd build + cmake .. -DUSE_BACKEND=MIGRAPHX -DCMAKE_BUILD_TYPE=Release + make -j$(nproc) + ``` + * On first launch, MIGraphX compiles and caches GPU programs for each batch size (4, 8, 16, 24, 32, 40, 64 up to `maxBatchSize`) in `~/.katago/migraphxcache/`. This initial compilation may take several minutes but subsequent launches load from cache instantly. + * MIGraphX may offer better GPU utilization and throughput than the ROCm/MIOpen backend on some workloads due to whole-graph operator fusion. + ## Windows * TLDR: * Building from source on Windows is actually a bit tricky, depending on what version you're building, there's not necessarily a super-fast way. @@ -117,6 +143,42 @@ As also mentioned in the instructions below but repeated here for visibility, if * You will probably want to edit `configs/gtp_example.cfg` (see "Tuning for Performance" above). * If using OpenCL, you will want to verify that KataGo is picking up the correct device (e.g. some systems may have both an Intel CPU OpenCL and GPU OpenCL, if KataGo appears to pick the wrong one, you can correct this by specifying `openclGpuToUse` in `configs/gtp_example.cfg`). + * **ROCm backend (Windows) — building via AMD TheRock:** + * The ROCm (MIOpen) backend supports Windows via [AMD TheRock](https://github.com/ROCm/TheRock) (tested with TheRock 7.12.0 / ROCm 7.2.0, RX 7900 XTX / gfx1100). + * **Prerequisites:** + * Install ROCm following the [official guide](https://rocm.docs.amd.com/en/7.12.0-preview/install/rocm.html). For Windows, download [AMD TheRock](https://github.com/ROCm/TheRock) and extract to e.g. `C:\TheRock\build`. + * Install **Visual Studio 2026 Build Tools** or **Visual Studio 2026 Community** with the "Desktop development with C++" workload. This provides the MSVC toolchain and Windows SDK required by the HIP compiler. + * Install [Ninja](https://ninja-build.org) build tool: `winget install Ninja-build.Ninja`. + * Set the following **system environment variables** (via System Properties → Advanced → Environment Variables): + ``` + HIP_PATH=C:/TheRock/build + HIP_PLATFORM=amd + HIP_DEVICE_LIB_PATH=C:/TheRock/build/lib/llvm/amdgcn/bitcode + LLVM_PATH=C:/TheRock/build/lib/llvm + ``` + * Add to system `PATH`: + ``` + C:\TheRock\build\bin + C:\TheRock\build\lib\llvm\bin + ``` + * Reboot after setting environment variables so they take effect system-wide. + * **Build** (from a terminal with the above env vars active): + ``` + cd KataGo/cpp + mkdir build + cd build + cmake .. -G Ninja -DUSE_BACKEND=ROCM -DCMAKE_BUILD_TYPE=Release + ninja -j $env:NUMBER_OF_PROCESSORS + ``` + No additional `-D` flags are needed — `CMakeLists.txt` automatically detects the HIP/clang compiler, GPU architecture (via `amdgpu-arch.exe`), Windows SDK include paths, and zlib from `HIP_PATH`. + * **Runtime DLL setup** — copy the following next to `katago.exe`: + * `amdhip64_7.dll` — **required**: must be copied from `D:\TheRock\build\bin\` to override the incompatible version that AMD GPU drivers install into `C:\Windows\System32\`. + * All other ROCm DLLs (`MIOpen.dll`, `hipblas.dll`, `rocblas.dll`, `hiprtc0702.dll`, `amd_comgr0702.dll`, `libhipblaslt.dll`, `amdocl64.dll`) are found automatically from `D:\TheRock\build\bin\` via `PATH` — no need to copy them. + * If `rocblas.dll` is copied, also copy the `rocblas\library\` directory alongside it (rocBLAS looks for its kernel files relative to its own DLL location). + * MSVC runtime DLLs (`msvcp140.dll`, `vcruntime140.dll`, etc.) are in `C:\Windows\System32\` on any machine with the Visual C++ Redistributable installed. + * **First-run note:** MIOpen will search for optimal convolution algorithms on the first run. This may take 45+ seconds per network configuration and results are cached in `%USERPROFILE%\.miopen\` for subsequent runs. Do not terminate the process during this initial tuning. + * **Performance note:** GPU utilization on Windows may be somewhat lower than on Linux due to the Windows Driver Model (WDDM) adding overhead to GPU kernel submissions. This is a known limitation of ROCm on Windows. + ## MacOS * TLDR: ``` diff --git a/README.md b/README.md index ce7e87b97..b17532fed 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,30 @@ # KataGo -* [Overview](#overview) -* [Training History and Research](#training-history-and-research) -* [Where To Download Stuff](#where-to-download-stuff) -* [Setting Up and Running KataGo](#setting-up-and-running-katago) - * [GUIs](#guis) - * [Windows and Linux](#windows-and-linux) - * [MacOS](#macos) - * [OpenCL vs CUDA vs TensorRT vs Eigen](#opencl-vs-cuda-vs-tensorrt-vs-eigen) - * [How To Use](#how-to-use) - * [Tuning for Performance](#tuning-for-performance) - * [Common Questions and Issues](#common-questions-and-issues) - * [Issues with specific GPUs or GPU drivers](#issues-with-specific-gpus-or-gpu-drivers) - * [Common Problems](#common-problems) - * [Other Questions](#other-questions) -* [Features for Developers](#features-for-developers) - * [GTP Extensions](#gtp-extensions) - * [Analysis Engine](#analysis-engine) -* [Compiling KataGo](#compiling-katago) -* [Source Code Overview](#source-code-overview) -* [Selfplay Training](#selfplay-training) -* [Contributors](#contributors) -* [License](#license) +- [KataGo](#katago) + - [Overview](#overview) + - [Training History and Research and Docs](#training-history-and-research-and-docs) + - [Where To Download Stuff](#where-to-download-stuff) + - [Setting Up and Running KataGo](#setting-up-and-running-katago) + - [GUIs](#guis) + - [Windows and Linux](#windows-and-linux) + - [MacOS](#macos) + - [OpenCL vs CUDA vs TensorRT vs ROCm vs MIGraphX vs Eigen](#opencl-vs-cuda-vs-tensorrt-vs-rocm-vs-migraphx-vs-eigen) + - [How To Use](#how-to-use) + - [Human-style Play and Analysis](#human-style-play-and-analysis) + - [Other Commands:](#other-commands) + - [Tuning for Performance](#tuning-for-performance) + - [Common Questions and Issues](#common-questions-and-issues) + - [Issues with specific GPUs or GPU drivers](#issues-with-specific-gpus-or-gpu-drivers) + - [Common Problems](#common-problems) + - [Other Questions](#other-questions) + - [Features for Developers](#features-for-developers) + - [GTP Extensions:](#gtp-extensions) + - [Analysis Engine:](#analysis-engine) + - [Compiling KataGo](#compiling-katago) + - [Source Code Overview:](#source-code-overview) + - [Selfplay Training:](#selfplay-training) + - [Contributors](#contributors) + - [License](#license) ## Overview @@ -84,8 +87,8 @@ The community also provides KataGo packages for [Homebrew](https://brew.sh) on M Use `brew install katago`. The latest config files and networks are installed in KataGo's `share` directory. Find them via `brew list --verbose katago`. A basic way to run katago will be `katago gtp -config $(brew list --verbose katago | grep 'gtp.*\.cfg') -model $(brew list --verbose katago | grep .gz | head -1)`. You should choose the Network according to the release notes here and customize the provided example config as with every other way of installing KataGo. -### OpenCL vs CUDA vs TensorRT vs Eigen -KataGo has four backends, OpenCL (GPU), CUDA (GPU), TensorRT (GPU), and Eigen (CPU). +### OpenCL vs CUDA vs TensorRT vs ROCm vs MIGraphX vs Eigen +KataGo has six backends, OpenCL (GPU), CUDA (GPU), TensorRT (GPU), ROCm (GPU), MIGraphX (GPU) and Eigen (CPU). The quick summary is: * **To easily get something working, try OpenCL if you have any good or decent GPU.** @@ -93,11 +96,15 @@ The quick summary is: * Use Eigen with AVX2 if you don't have a GPU or if your GPU is too old/weak to work with OpenCL, and you just want a plain CPU KataGo. * Use Eigen without AVX2 if your CPU is old or on a low-end device that doesn't support AVX2. * The CUDA backend can work for NVIDIA GPUs with CUDA+CUDNN installed but is likely worse than TensorRT. + * The ROCm backend can work for AMD GPUs with ROCm+MIOpen installed. + * The MIGraphX backend is an alternative AMD GPU backend using MIGraphX instead of MIOpen. More in detail: * OpenCL is a general GPU backend should be able to run with any GPUs or accelerators that support [OpenCL](https://en.wikipedia.org/wiki/OpenCL), including NVIDIA GPUs, AMD GPUs, as well CPU-based OpenCL implementations or things like Intel Integrated Graphics. This is the most general GPU version of KataGo and doesn't require a complicated install like CUDA does, so is most likely to work out of the box as long as you have a fairly modern GPU. **However, it also need to take some time when run for the very first time to tune itself.** For many systems, this will take 5-30 seconds, but on a few older/slower systems, may take many minutes or longer. Also, the quality of OpenCL implementations is sometimes inconsistent, particularly for Intel Integrated Graphics and for AMD GPUs that are older than several years, so it might not work for very old machines, as well as specific buggy newer AMD GPUs, see also [Issues with specific GPUs or GPU drivers](#issues-with-specific-gpus-or-gpu-drivers). * CUDA is a GPU backend specific to NVIDIA GPUs (it will not work with AMD or Intel or any other GPUs) and requires installing [CUDA](https://developer.nvidia.com/cuda-zone) and [CUDNN](https://developer.nvidia.com/cudnn) and a modern NVIDIA GPU. On most GPUs, the OpenCL implementation will actually beat NVIDIA's own CUDA/CUDNN at performance. The exception is for top-end NVIDIA GPUs that support FP16 and tensor cores, in which case sometimes one is better and sometimes the other is better. * TensorRT is similar to CUDA, but only uses NVIDIA's TensorRT framework to run the neural network with more optimized kernels. For modern NVIDIA GPUs, it should work whenever CUDA does and will usually be faster than CUDA or any other backend. + * ROCm is a GPU backend specific to AMD GPUs (it will not work with NVIDIA or Intel or any other GPUs) and requires installing [ROCm](https://rocm.docs.amd.com) and [MIOpen](https://rocm.docs.amd.com/projects/MIOpen) and a modern AMD GPU. Supports both **Linux** (via official ROCm packages, ROCm 6.4+) and **Windows** (via [AMD TheRock](https://github.com/ROCm/TheRock) builds). On most GPUs, the OpenCL implementation will actually beat AMD's own ROCm/MIOpen at performance. The exception is for top-end AMD GPUs that support FP16 and stream processors, in which case sometimes one is better and sometimes the other is better. + * MIGraphX is an alternative GPU backend for AMD GPUs using AMD's MIGraphX graph-compiler framework instead of MIOpen. It compiles the entire neural network into a single fused GPU program, which can offer better throughput than ROCm/MIOpen on some workloads. Requires ROCm 7.0+ with MIGraphX installed. Currently supports Linux only. * Eigen is a *CPU* backend that should work widely *without* needing a GPU or fancy drivers. Use this if you don't have a good GPU or really any GPU at all. It will be quite significantly slower than OpenCL or CUDA, but on a good CPU can still often get 10 to 20 playouts per second if using the smaller (15 or 20) block neural nets. Eigen can also be compiled with AVX2 and FMA support, which can provide a big performance boost for Intel and AMD CPUs from the last few years. However, it will not run at all on older CPUs (and possibly even some recent but low-power modern CPUs) that don't support these fancy vector instructions. For **any** implementation, it's recommended that you also tune the number of threads used if you care about optimal performance, as it can make a factor of 2-3 difference in the speed. See "Tuning for Performance" below. However, if you mostly just want to get it working, then the default untuned settings should also be still reasonable. @@ -175,6 +182,8 @@ This section summarizes a number of common questions and issues when running Kat #### Issues with specific GPUs or GPU drivers If you are observing any crashes in KataGo while attempting to run the benchmark or the program itself, and you have one of the below GPUs, then this is likely the reason. +* **AMD GPUs** - If you choose to use the ROCm backend, you need a GPU on the official [System requirements list](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html) (at least AMD Radeon RX 7700 XT). ROCm backend supports both Linux (via official ROCm packages) and Windows (via [AMD TheRock](https://github.com/ROCm/TheRock) builds). On Linux, install the full ROCm developer stack. On Windows, see the ROCm Windows build instructions in [Compiling.md](Compiling.md). The MIGraphX backend also requires ROCm 7.0+ with MIGraphX installed and currently supports Linux only. + * **AMD Radeon RX 5700** - AMD's drivers for OpenCL for this GPU have been buggy ever since this GPU was released, and as of May 2020 AMD has still never released a fix. If you are using this GPU, you will just not be able to run KataGo (Leela Zero and other Go engines will probably fail too) and will probably also obtain incorrect calculations or crash if doing anything else scientific or mathematical that uses OpenCL. See for example these reddit threads: [[1]](https://www.reddit.com/r/Amd/comments/ebso1x/its_not_just_setihome_any_mathematic_or/) or [[2]](https://www.reddit.com/r/BOINC/comments/ebiz18/psa_please_remove_your_amd_rx5700xt_from_setihome/) or this [L19 thread](https://lifein19x19.com/viewtopic.php?f=18&t=17093). * **OpenCL Mesa** - These drivers for OpenCL are buggy. Particularly if on startup before crashing you see KataGo printing something like `Found OpenCL Platform 0: ... (Mesa) (OpenCL 1.1 Mesa ...) ...` diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 8db79ca73..d63ca69e2 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,6 +1,141 @@ cmake_minimum_required(VERSION 3.18.2) if(USE_BACKEND STREQUAL "METAL") project(katago LANGUAGES CXX Swift) +elseif(USE_BACKEND STREQUAL "ROCM") + if(WIN32) + # Normalize HIP_PATH to forward slashes. + # cmake's HIP detection calls `hipconfig --rocmpath` which returns + # backslash paths on Windows; it writes the result into the generated + # CMakeHIPCompiler.cmake without normalizing, causing "Invalid character + # escape '\T'" errors. Pre-setting CMAKE_HIP_COMPILER_ROCM_ROOT (and + # CMAKE_PREFIX_PATH) before project() bypasses that detection entirely. + if(DEFINED ENV{HIP_PATH}) + file(TO_CMAKE_PATH "$ENV{HIP_PATH}" _hip_path_fwd) + set(ENV{HIP_PATH} "${_hip_path_fwd}") + list(APPEND CMAKE_PREFIX_PATH "${_hip_path_fwd}") + if(NOT CMAKE_HIP_COMPILER_ROCM_ROOT) + set(CMAKE_HIP_COMPILER_ROCM_ROOT "${_hip_path_fwd}" CACHE PATH "" FORCE) + endif() + endif() + # ---------- C/C++ compiler (clang++ from HIP SDK) ---------- + if(NOT CMAKE_CXX_COMPILER) + if(DEFINED ENV{HIP_PATH}) + if(EXISTS "$ENV{HIP_PATH}/lib/llvm/bin/clang++.exe") + set(CMAKE_CXX_COMPILER "$ENV{HIP_PATH}/lib/llvm/bin/clang++.exe" CACHE FILEPATH "" FORCE) + set(CMAKE_C_COMPILER "$ENV{HIP_PATH}/lib/llvm/bin/clang.exe" CACHE FILEPATH "" FORCE) + elseif(EXISTS "$ENV{HIP_PATH}/bin/clang++.exe") + set(CMAKE_CXX_COMPILER "$ENV{HIP_PATH}/bin/clang++.exe" CACHE FILEPATH "" FORCE) + set(CMAKE_C_COMPILER "$ENV{HIP_PATH}/bin/clang.exe" CACHE FILEPATH "" FORCE) + endif() + endif() + endif() + # HIP compiler = same binary as C++ compiler + if(NOT CMAKE_HIP_COMPILER AND CMAKE_CXX_COMPILER) + set(CMAKE_HIP_COMPILER "${CMAKE_CXX_COMPILER}" CACHE FILEPATH "" FORCE) + endif() + # ---------- HIP architectures (must be set before project() / enable_language(HIP)) ---------- + if(NOT DEFINED CMAKE_HIP_ARCHITECTURES AND DEFINED ENV{HIP_PATH}) + # TheRock layout: lib/llvm/bin/; standard HIP SDK layout: bin/ + foreach(_arch_candidate + "$ENV{HIP_PATH}/lib/llvm/bin/amdgpu-arch.exe" + "$ENV{HIP_PATH}/bin/amdgpu-arch.exe") + if(EXISTS "${_arch_candidate}") + set(_amdgpu_arch_exe "${_arch_candidate}") + break() + endif() + endforeach() + if(EXISTS "${_amdgpu_arch_exe}") + execute_process(COMMAND "${_amdgpu_arch_exe}" + OUTPUT_VARIABLE _detected_archs OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) + if(_detected_archs) + string(REPLACE "\n" ";" _arch_list "${_detected_archs}") + # Filter to only valid gfxNNNN entries (amdgpu-arch may also print + # "HIP Library Path: ..." header lines on some installations) + set(_filtered_archs "") + foreach(_a ${_arch_list}) + if(_a MATCHES "^gfx[0-9]") + list(APPEND _filtered_archs "${_a}") + endif() + endforeach() + list(REMOVE_DUPLICATES _filtered_archs) + if(_filtered_archs) + set(CMAKE_HIP_ARCHITECTURES "${_filtered_archs}" CACHE STRING "Auto-detected AMD GPU targets") + message(STATUS "Pre-project auto-detected AMD GPU architectures: ${CMAKE_HIP_ARCHITECTURES}") + endif() + endif() + endif() + if(NOT CMAKE_HIP_ARCHITECTURES) + # Conservative fallback covering RDNA2/3/4 and CDNA2/3 + set(CMAKE_HIP_ARCHITECTURES "gfx1030;gfx1100;gfx1101;gfx1151;gfx1201;gfx90a;gfx942;gfx950" CACHE STRING "Fallback AMD GPU targets") + message(STATUS "amdgpu-arch not available; using fallback architectures: ${CMAKE_HIP_ARCHITECTURES}") + endif() + endif() + # ---------- Windows SDK includes (needed by HIP compiler test during project()) ---------- + # The HIP runtime wrapper includes MSVC headers that require Windows SDK ucrt/shared/um. + # These flags must be set before project() so cmake's HIP compiler test can compile. + if(NOT KATAGO_WINSDK_ROOT) + get_filename_component(_pre_winsdk_root + "[HKEY_LOCAL_MACHINE\\SOFTWARE\\Microsoft\\Windows Kits\\Installed Roots;KitsRoot10]" + ABSOLUTE) + if(NOT EXISTS "${_pre_winsdk_root}") + foreach(_p "C:/Program Files (x86)/Windows Kits/10" + "C:/Program Files/Windows Kits/10") + if(EXISTS "${_p}") + set(_pre_winsdk_root "${_p}") + break() + endif() + endforeach() + endif() + if(EXISTS "${_pre_winsdk_root}") + set(KATAGO_WINSDK_ROOT "${_pre_winsdk_root}" CACHE INTERNAL "") + endif() + endif() + if(KATAGO_WINSDK_ROOT) + file(GLOB _pre_sdk_ver_dirs "${KATAGO_WINSDK_ROOT}/Include/*/ucrt") + if(_pre_sdk_ver_dirs) + list(SORT _pre_sdk_ver_dirs ORDER DESCENDING) + list(GET _pre_sdk_ver_dirs 0 _pre_ucrt_dir) + get_filename_component(_pre_sdk_ver_dir "${_pre_ucrt_dir}" DIRECTORY) + set(_winsdk_cflags + "-I\"${_pre_sdk_ver_dir}/ucrt\" -I\"${_pre_sdk_ver_dir}/shared\" -I\"${_pre_sdk_ver_dir}/um\"") + set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} ${_winsdk_cflags}" CACHE STRING "" FORCE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${_winsdk_cflags}" CACHE STRING "" FORCE) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${_winsdk_cflags}" CACHE STRING "" FORCE) + message(STATUS "Pre-project: injected Windows SDK includes from ${_pre_sdk_ver_dir}") + endif() + endif() + # ---------- RC compiler (Windows resource compiler, highest SDK version) ---------- + if(NOT CMAKE_RC_COMPILER) + # Try registry first + get_filename_component(_winsdk_root + "[HKEY_LOCAL_MACHINE\\SOFTWARE\\Microsoft\\Windows Kits\\Installed Roots;KitsRoot10]" + ABSOLUTE) + if(NOT EXISTS "${_winsdk_root}") + foreach(_p "C:/Program Files (x86)/Windows Kits/10" + "C:/Program Files/Windows Kits/10") + if(EXISTS "${_p}") + set(_winsdk_root "${_p}") + break() + endif() + endforeach() + endif() + if(EXISTS "${_winsdk_root}") + file(GLOB _rc_candidates "${_winsdk_root}/bin/*/x64/rc.exe") + if(_rc_candidates) + list(SORT _rc_candidates ORDER DESCENDING) + list(GET _rc_candidates 0 _rc_exe) + set(CMAKE_RC_COMPILER "${_rc_exe}" CACHE FILEPATH "" FORCE) + endif() + # Persist for use in compiler flags section below + set(KATAGO_WINSDK_ROOT "${_winsdk_root}" CACHE INTERNAL "") + endif() + endif() + else() + # Linux: Use hipcc + set(CMAKE_C_COMPILER /opt/rocm/bin/hipcc CACHE FILEPATH "" FORCE) + set(CMAKE_CXX_COMPILER /opt/rocm/bin/hipcc CACHE FILEPATH "" FORCE) + endif() + project(katago LANGUAGES C CXX HIP) else() project(katago) endif() @@ -32,7 +167,7 @@ endif() set(BUILD_DISTRIBUTED 0 CACHE BOOL "Build with http support for contributing to distributed training") set(USE_BACKEND CACHE STRING "Neural net backend") string(TOUPPER "${USE_BACKEND}" USE_BACKEND) -set_property(CACHE USE_BACKEND PROPERTY STRINGS "" CUDA TENSORRT OPENCL EIGEN) +set_property(CACHE USE_BACKEND PROPERTY STRINGS "" CUDA TENSORRT OPENCL EIGEN ROCM MIGRAPHX) set(USE_TCMALLOC 0 CACHE BOOL "Use TCMalloc") set(NO_GIT_REVISION 0 CACHE BOOL "Disable embedding the git revision into the compiled exe") @@ -145,6 +280,140 @@ elseif(USE_BACKEND STREQUAL "EIGEN") set(NEURALNET_BACKEND_SOURCES neuralnet/eigenbackend.cpp ) +# --------------------------- ROCM backend(AMD GPU / HIP MIOpen) --------------------------- +elseif(USE_BACKEND STREQUAL "ROCM") + message(STATUS "-DUSE_BACKEND=ROCM, using AMD ROCm backend.") + + enable_language(HIP) + set(CMAKE_HIP_STANDARD 17) + + if(CMAKE_PREFIX_PATH STREQUAL "" OR NOT DEFINED CMAKE_PREFIX_PATH) + if(WIN32) + # Windows: HIP SDK installed via installer or manually + if(DEFINED ENV{HIP_PATH}) + list(APPEND CMAKE_PREFIX_PATH "$ENV{HIP_PATH}") + message(STATUS "Auto-detected HIP_PATH=$ENV{HIP_PATH} → CMAKE_PREFIX_PATH") + elseif(DEFINED ENV{ROCM_PATH}) + list(APPEND CMAKE_PREFIX_PATH "$ENV{ROCM_PATH}") + message(STATUS "Auto-detected ROCM_PATH=$ENV{ROCM_PATH} → CMAKE_PREFIX_PATH") + else() + message(WARNING "HIP_PATH or ROCM_PATH environment variable not set. Please install HIP SDK for Windows.") + endif() + else() + # Linux: Standard ROCm installation path + if(EXISTS "/opt/rocm") + list(APPEND CMAKE_PREFIX_PATH "/opt/rocm") + message(STATUS "CMAKE_PREFIX_PATH not given; defaulting to /opt/rocm") + endif() + endif() + endif() + + # Users can -DCMAKE_HIP_ARCHITECTURES=gfx90a;gfx942 manually specify GFX architectures + if(NOT DEFINED CMAKE_HIP_ARCHITECTURES) + # Auto-detect installed GPU architectures via amdgpu-arch + set(_amdgpu_arch_exe "") + if(WIN32 AND DEFINED ENV{HIP_PATH}) + foreach(_arch_cand + "$ENV{HIP_PATH}/lib/llvm/bin/amdgpu-arch.exe" + "$ENV{HIP_PATH}/bin/amdgpu-arch.exe") + if(EXISTS "${_arch_cand}") + set(_amdgpu_arch_exe "${_arch_cand}") + break() + endif() + endforeach() + elseif(EXISTS "/opt/rocm/bin/amdgpu-arch") + set(_amdgpu_arch_exe "/opt/rocm/bin/amdgpu-arch") + endif() + if(_amdgpu_arch_exe) + execute_process(COMMAND "${_amdgpu_arch_exe}" + OUTPUT_VARIABLE _detected_archs OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) + if(_detected_archs) + string(REPLACE "\n" ";" _arch_list "${_detected_archs}") + set(_filtered_archs2 "") + foreach(_a ${_arch_list}) + if(_a MATCHES "^gfx[0-9]") + list(APPEND _filtered_archs2 "${_a}") + endif() + endforeach() + list(REMOVE_DUPLICATES _filtered_archs2) + if(_filtered_archs2) + set(CMAKE_HIP_ARCHITECTURES "${_filtered_archs2}" CACHE STRING "Auto-detected AMD GPU targets") + message(STATUS "Auto-detected AMD GPU architectures: ${CMAKE_HIP_ARCHITECTURES}") + endif() + endif() + endif() + if(NOT CMAKE_HIP_ARCHITECTURES) + # Fallback: compile for a broad range of supported architectures + add_compile_definitions(-DGPU_TARGETS=gfx950,gfx942,gfx90a,gfx908,gfx1100,gfx1101,gfx1151,gfx1201,gfx1030) + endif() + endif() + + # 2) Specify backend source code. rocmhelpers.hip contains GPU kernels, don't forget it + set(NEURALNET_BACKEND_SOURCES + neuralnet/rocmbackend.cpp + neuralnet/rocmutils.cpp + neuralnet/rocmhelpers.hip + ) + + # Optional: Enable model-size‑based autotuning and other macros + # add_compile_definitions(HIP_SUPPORTS_FP16) + +# --------------------------- MIGRAPHX backend(AMD MIGraphX graph inference) --------------------------- +elseif(USE_BACKEND STREQUAL "MIGRAPHX") + message(STATUS "-DUSE_BACKEND=MIGRAPHX, using AMD MIGraphX backend.") + + # Use standard C++ compiler with MIGraphX + set(CMAKE_CXX_STANDARD 17) + + # Find MIGraphX manually (avoid CMake config which adds hipcc-specific flags) + # Note: MIGraphX headers are split between two locations: + # - /opt/rocm/lib/migraphx/include/migraphx/ (C++ API headers like program.hpp) + # - /opt/rocm/include/migraphx/ (export.h and other common headers) + find_path(MIGRAPHX_CXX_INCLUDE_DIR migraphx/program.hpp + HINTS /opt/rocm/lib/migraphx/include + PATH_SUFFIXES include) + + find_path(MIGRAPHX_INCLUDE_DIR migraphx/export.h + HINTS /opt/rocm/include + PATH_SUFFIXES include) + + find_library(MIGRAPHX_LIBRARY migraphx + HINTS /opt/rocm/lib/migraphx/lib /opt/rocm/lib + PATH_SUFFIXES lib lib64) + + find_library(MIGRAPHX_GPU_LIBRARY migraphx_gpu + HINTS /opt/rocm/lib/migraphx/lib /opt/rocm/lib + PATH_SUFFIXES lib lib64) + + if(NOT MIGRAPHX_CXX_INCLUDE_DIR) + message(FATAL_ERROR "MIGraphX C++ headers not found. Please install MIGraphX.") + endif() + + if(NOT MIGRAPHX_LIBRARY) + message(FATAL_ERROR "MIGraphX library not found. Please install MIGraphX.") + endif() + + message(STATUS "MIGraphX C++ include: ${MIGRAPHX_CXX_INCLUDE_DIR}") + message(STATUS "MIGraphX include: ${MIGRAPHX_INCLUDE_DIR}") + message(STATUS "MIGraphX library: ${MIGRAPHX_LIBRARY}") + if(MIGRAPHX_GPU_LIBRARY) + message(STATUS "MIGraphX GPU library: ${MIGRAPHX_GPU_LIBRARY}") + endif() + + # Source files for MIGraphX backend + set(NEURALNET_BACKEND_SOURCES + neuralnet/migraphxbackend.cpp + ) + + # Include directories (both locations needed) + include_directories(SYSTEM ${MIGRAPHX_CXX_INCLUDE_DIR}) + if(MIGRAPHX_INCLUDE_DIR) + include_directories(SYSTEM ${MIGRAPHX_INCLUDE_DIR}) + endif() + + # Add ROCm lib directory for linking + link_directories(/opt/rocm/lib) + elseif(USE_BACKEND STREQUAL "") message(WARNING "${ColorBoldRed}WARNING: Using dummy neural net backend, intended for non-neural-net testing only, will fail on any code path requiring a neural net. To use neural net, specify -DUSE_BACKEND=CUDA or -DUSE_BACKEND=TENSORRT or -DUSE_BACKEND=OPENCL or -DUSE_BACKEND=EIGEN to compile with the respective backend.${ColorReset}") set(NEURALNET_BACKEND_SOURCES neuralnet/dummybackend.cpp) @@ -428,6 +697,110 @@ elseif(USE_BACKEND STREQUAL "OPENCL") link_directories(${OpenCL_LIBRARY}) target_link_libraries(katago ${OpenCL_LIBRARY}) endif() +# --------------------------- ROCM linking stage --------------------------- +elseif(USE_BACKEND STREQUAL "ROCM") + # Macro: used in source code with #ifdef USE_ROCM_BACKEND + target_compile_definitions(katago PRIVATE USE_ROCM_BACKEND) + target_compile_definitions(katago PRIVATE HIP_TARGET_VERSION=${CMAKE_HIP_COMPILER_VERSION}) + + string(TOLOWER "${CMAKE_HIP_ARCHITECTURES}" _gfxlist) # e.g. "90a;942" + if(_gfxlist MATCHES "803|900|90a|94[0-9]|110[0-9]|120[0-9]|115[0-9]|1030") + target_compile_definitions(katago PRIVATE HIP_SUPPORTS_FP16) + message(STATUS "Detected FP16‑capable GFX arch (${CMAKE_HIP_ARCHITECTURES}); defining HIP_SUPPORTS_FP16") + endif() + + # 3) Find ROCm runtime & libraries. Since ROCm 6.x, CMake config-mode packages are included. If not found, add -DCMAKE_PREFIX_PATH=/opt/rocm (Linux) or HIP SDK path (Windows) + find_package(hip QUIET CONFIG) # Export hip::device / hip::host + find_package(hipblas QUIET CONFIG) # Export roc::hipblas + find_package(miopen QUIET CONFIG) # Export roc::miopen or MIOpen + + # ---------- fallback:HIP Runtime ---------- + if(NOT hip_FOUND) + if(WIN32) + # Windows: Search in HIP SDK installation + find_path(HIP_INCLUDE_DIR hip/hip_runtime.h + HINTS ${CMAKE_PREFIX_PATH} ENV HIP_PATH ENV ROCM_PATH + PATH_SUFFIXES include) + find_library(HIP_RUNTIME_LIB + NAMES amdhip64 amdhip64_6 + HINTS ${CMAKE_PREFIX_PATH} ENV HIP_PATH ENV ROCM_PATH + PATH_SUFFIXES lib bin) + else() + # Linux: Search in /opt/rocm + find_path(HIP_INCLUDE_DIR hip/hip_runtime.h + HINTS ${CMAKE_PREFIX_PATH} /opt/rocm + PATH_SUFFIXES include) + find_library(HIP_RUNTIME_LIB amdhip64 + HINTS ${CMAKE_PREFIX_PATH} /opt/rocm + PATH_SUFFIXES lib lib64) + endif() + + if(NOT HIP_INCLUDE_DIR OR NOT HIP_RUNTIME_LIB) + if(WIN32) + message(FATAL_ERROR "HIP headers or runtime NOT found; install HIP SDK for Windows or set CMAKE_PREFIX_PATH to HIP SDK installation path.") + else() + message(FATAL_ERROR "HIP headers or runtime NOT found; install ROCm or set CMAKE_PREFIX_PATH.") + endif() + endif() + add_library(hip::device UNKNOWN IMPORTED) + set_target_properties(hip::device PROPERTIES + IMPORTED_LOCATION "${HIP_RUNTIME_LIB}" + INTERFACE_INCLUDE_DIRECTORIES "${HIP_INCLUDE_DIR}") + target_include_directories(katago SYSTEM PRIVATE ${HIP_INCLUDE_DIR}) + endif() + + # ---------- fallback:hipBLAS / MIOpen ---------- + foreach(_pkg hipblas miopen) + if(NOT ${_pkg}_FOUND) + if(WIN32) + # Windows naming conventions + if(_pkg STREQUAL "hipblas") + set(_lib_names hipblas) + else() + set(_lib_names MIOpen) + endif() + find_library(${_pkg}_LIB + NAMES ${_lib_names} + HINTS ${CMAKE_PREFIX_PATH} ENV HIP_PATH ENV ROCM_PATH + PATH_SUFFIXES lib bin) + else() + # Linux naming + find_library(${_pkg}_LIB ${_pkg} + HINTS ${CMAKE_PREFIX_PATH} /opt/rocm + PATH_SUFFIXES lib lib64) + endif() + + if(${_pkg}_LIB) + add_library(roc::${_pkg} UNKNOWN IMPORTED) + set_target_properties(roc::${_pkg} PROPERTIES + IMPORTED_LOCATION "${${_pkg}_LIB}") + target_include_directories(katago SYSTEM PRIVATE ${HIP_INCLUDE_DIR}) + message(STATUS "Found ${_pkg} at ${${_pkg}_LIB}") + else() + if(WIN32) + message(FATAL_ERROR "Required ROCm component ${_pkg} not found – install HIP SDK for Windows or set CMAKE_PREFIX_PATH.") + else() + message(FATAL_ERROR "Required ROCm component ${_pkg} not found – install it or set CMAKE_PREFIX_PATH.") + endif() + endif() + endif() + endforeach() + + # 4) Link libraries + # Note: On Windows, MIOpen might need to be linked as "MIOpen" directly if the target doesn't exist + if(TARGET MIOpen) + set(_miopen_target MIOpen) + elseif(TARGET roc::miopen) + set(_miopen_target roc::miopen) + else() + set(_miopen_target roc::miopen) + endif() + + target_link_libraries(katago + hip::device # HIP runtime & kernel offload + roc::hipblas # BLAS + ${_miopen_target} # DNN primitives + ) elseif(USE_BACKEND STREQUAL "EIGEN") target_compile_definitions(katago PRIVATE USE_EIGEN_BACKEND) if(NOT (MSVC)) @@ -449,6 +822,35 @@ elseif(USE_BACKEND STREQUAL "EIGEN") endif() endif() endif() +elseif(USE_BACKEND STREQUAL "MIGRAPHX") + target_compile_definitions(katago PRIVATE USE_MIGRAPHX_BACKEND) + + # Link MIGraphX libraries + target_link_libraries(katago ${MIGRAPHX_LIBRARY}) + if(MIGRAPHX_GPU_LIBRARY) + target_link_libraries(katago ${MIGRAPHX_GPU_LIBRARY}) + endif() + + # Link HIP runtime + find_library(AMDHIP64_LIBRARY amdhip64 + HINTS /opt/rocm/lib + PATH_SUFFIXES lib lib64) + if(AMDHIP64_LIBRARY) + target_link_libraries(katago ${AMDHIP64_LIBRARY}) + else() + target_link_libraries(katago amdhip64) + endif() + + # Link other required libraries + find_library(HIPRTC_LIBRARY hiprtc + HINTS /opt/rocm/lib + PATH_SUFFIXES lib lib64) + if(HIPRTC_LIBRARY) + target_link_libraries(katago ${HIPRTC_LIBRARY}) + endif() + + # Add ROCm library directories + link_directories(/opt/rocm/lib) endif() if(USE_BIGGER_BOARDS_EXPENSIVE) @@ -459,6 +861,21 @@ if(NO_GIT_REVISION AND (NOT BUILD_DISTRIBUTED)) target_compile_definitions(katago PRIVATE NO_GIT_REVISION) endif() +# On Windows ROCm builds, zlib is bundled inside the HIP SDK (TheRock layout) +if(WIN32 AND USE_BACKEND STREQUAL "ROCM" AND DEFINED ENV{HIP_PATH}) + if(NOT ZLIB_INCLUDE_DIR AND EXISTS "$ENV{HIP_PATH}/lib/rocm_sysdeps/include/zlib.h") + set(ZLIB_INCLUDE_DIR "$ENV{HIP_PATH}/lib/rocm_sysdeps/include" CACHE PATH "" FORCE) + endif() + if(NOT ZLIB_LIBRARY) + foreach(_zlib_name "zlibstatic.lib" "zlib.lib" "zlibstaticd.lib") + if(EXISTS "$ENV{HIP_PATH}/lib/rocm_sysdeps/lib/${_zlib_name}") + set(ZLIB_LIBRARY "$ENV{HIP_PATH}/lib/rocm_sysdeps/lib/${_zlib_name}" CACHE FILEPATH "" FORCE) + break() + endif() + endforeach() + endif() +endif() + find_package(ZLIB) if(ZLIB_FOUND) include_directories(${ZLIB_INCLUDE_DIRS}) @@ -557,7 +974,7 @@ if(MSVC) set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /STACK:8388608") elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") message(STATUS "Setting up build for GNU, Clang or MinGW.") - if(NOT (${CMAKE_SYSTEM_PROCESSOR} MATCHES "(arm|aarch32|aarch64)")) + if(NOT (${CMAKE_SYSTEM_PROCESSOR} MATCHES "(arm|aarch32|aarch64)") AND NOT USE_BACKEND STREQUAL "ROCM") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpmath=sse") else() # For ARM architecture, as a hack, ensure that char is signed @@ -586,7 +1003,10 @@ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "C else() message(STATUS "Enabling Clang-specific build options.") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wnull-dereference -Wdangling-else") - target_link_libraries(katago "atomic") + if(NOT WIN32) + # libatomic is a Linux GCC/Clang runtime; not needed (or available) on Windows + target_link_libraries(katago "atomic") + endif() endif() if(USE_TCMALLOC) @@ -598,3 +1018,41 @@ endif() target_include_directories(katago PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) +# On Windows ROCm builds, clang compiles all files with -x hip which pulls in +# MSVC-compatibility headers that require Windows SDK ucrt/shared/um headers. +if(WIN32 AND USE_BACKEND STREQUAL "ROCM") + # Prefer the KATAGO_WINSDK_ROOT detected earlier in the pre-project() block + if(NOT KATAGO_WINSDK_ROOT) + get_filename_component(_winsdk_root2 + "[HKEY_LOCAL_MACHINE\\SOFTWARE\\Microsoft\\Windows Kits\\Installed Roots;KitsRoot10]" + ABSOLUTE) + if(EXISTS "${_winsdk_root2}") + set(KATAGO_WINSDK_ROOT "${_winsdk_root2}" CACHE INTERNAL "") + else() + foreach(_p "C:/Program Files (x86)/Windows Kits/10" + "C:/Program Files/Windows Kits/10") + if(EXISTS "${_p}") + set(KATAGO_WINSDK_ROOT "${_p}" CACHE INTERNAL "") + break() + endif() + endforeach() + endif() + endif() + if(KATAGO_WINSDK_ROOT) + # Pick highest version directory + file(GLOB _sdk_ver_dirs "${KATAGO_WINSDK_ROOT}/Include/*/ucrt") + if(_sdk_ver_dirs) + list(SORT _sdk_ver_dirs ORDER DESCENDING) + list(GET _sdk_ver_dirs 0 _ucrt_dir) + get_filename_component(_sdk_ver_dir "${_ucrt_dir}" DIRECTORY) + set(_winsdk_include "${_sdk_ver_dir}") + message(STATUS "Auto-detected Windows SDK include root: ${_winsdk_include}") + target_include_directories(katago PRIVATE + "${_winsdk_include}/ucrt" + "${_winsdk_include}/shared" + "${_winsdk_include}/um" + ) + endif() + endif() +endif() + diff --git a/cpp/README.md b/cpp/README.md index 1f5d8d21f..eca41d070 100644 --- a/cpp/README.md +++ b/cpp/README.md @@ -15,7 +15,7 @@ Summary of source folders, in approximate dependency order, from lowest level to * `nninputs.{cpp,h}` - Implements the input features for the neural net. * `sgfmetadata.{cpp,h}` - Implements the input features for the [HumanSL neural net](https://github.com/lightvector/KataGo/blob/master/docs/Analysis_Engine.md#human-sl-analysis-guide), for conditioning on various SGF metadata about human players from training data. * `nninterface.h` - Common interface that is implemented by every low-level neural net backend. - * `{cuda,opencl,eigen,trt,dummy}backend.cpp` - Various backends. + * `{cuda,opencl,eigen,trt,rocm,mgx,metal,dummy}backend.cpp` - Various backends. * `nneval.{cpp,h}` - Top-level handle to the neural net used by the rest of the engine, implements thread-safe batching of queries. * `search` - The main search engine. * `timecontrols.cpp` - Basic handling of a few possible time controls. diff --git a/cpp/command/benchmark.cpp b/cpp/command/benchmark.cpp index 3100fb1b1..92f44aae9 100644 --- a/cpp/command/benchmark.cpp +++ b/cpp/command/benchmark.cpp @@ -265,6 +265,11 @@ int MainCmds::benchmark(const vector& args) { cout << "If you have a strong GPU capable of FP16 tensor cores (e.g. RTX2080), " << "using the Cuda version of KataGo instead may give a mild performance boost." << endl; #endif +#ifdef USE_ROCM_BACKEND + cout << "You are currently using the ROCm version of KataGo." << endl; + cout << "If you have a strong GPU capable of FP16 tensor cores (e.g. RX6900XT), " + << "using the ROCm version of KataGo instead may give a mild performance boost." << endl; +#endif #ifdef USE_EIGEN_BACKEND cout << "You are currently using the Eigen (CPU) version of KataGo. Due to having no GPU, it may be slow." << endl; #endif diff --git a/cpp/configs/analysis_example.cfg b/cpp/configs/analysis_example.cfg index edc5e8726..c3a70bd3c 100644 --- a/cpp/configs/analysis_example.cfg +++ b/cpp/configs/analysis_example.cfg @@ -219,9 +219,7 @@ nnRandomize = true # cudaUseNHWC = auto -# ------------------------------ -# Metal GPU settings -# ------------------------------ +# Metal GPU settings-------------------------------------- # These only apply when using the METAL version of KataGo. # For one Metal instance: KataGo will automatically use the default device. @@ -235,6 +233,55 @@ nnRandomize = true # The pattern continues for additional Metal instances. +# ROCm GPU settings-------------------------------------- +# These only apply when using the ROCm version of KataGo. + +# IF USING ONE GPU: optionally uncomment and change this if the GPU you want to use turns out to be not device 0 +# rocmDeviceToUse = 0 + +# IF USING TWO GPUS: Uncomment these two lines (AND set numNNServerThreadsPerModel above): +# rocmDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# rocmDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 + +# IF USING THREE GPUS: Uncomment these three lines (AND set numNNServerThreadsPerModel above): +# rocmDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# rocmDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 +# rocmDeviceToUseThread2 = 2 # change this if the third GPU you want to use turns out to be not device 2 + +# You can probably guess the pattern if you have four, five, etc. GPUs. + +# KataGo will automatically use FP16 or not based on the compute capability of your AMD GPU. If you +# want to try to force a particular behavior though you can uncomment these lines and change them +# to "true" or "false". E.g. it's using FP16 but on your card that's giving an error, or it's not using +# FP16 but you think it should. +# rocmUseFP16 = auto +# ROCm does not support NHWC, so this is always false. + + +# MIGraphX GPU settings-------------------------------------- +# These only apply when using the MIGraphX version of KataGo. + +# IF USING ONE GPU: optionally uncomment and change this if the GPU you want to use turns out to be not device 0 +# mgxDeviceToUse = 0 + +# IF USING TWO GPUS: Uncomment these two lines (AND set numNNServerThreadsPerModel above): +# mgxDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# mgxDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 + +# IF USING THREE GPUS: Uncomment these three lines (AND set numNNServerThreadsPerModel above): +# mgxDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# mgxDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 +# mgxDeviceToUseThread2 = 2 # change this if the third GPU you want to use turns out to be not device 2 + +# You can probably guess the pattern if you have four, five, etc. GPUs. + +# KataGo will automatically use FP16 or not based on the compute capability of your AMD GPU. If you +# want to try to force a particular behavior though you can uncomment these lines and change them +# to "true" or "false". E.g. it's using FP16 but on your card that's giving an error, or it's not using +# FP16 but you think it should. +# mgxUseFP16 = auto + + # OpenCL-specific GPU settings-------------------------------------- # These only apply when using the OpenCL version of KataGo. diff --git a/cpp/configs/contribute_example.cfg b/cpp/configs/contribute_example.cfg index 6ca039f11..5f2a2d1f8 100644 --- a/cpp/configs/contribute_example.cfg +++ b/cpp/configs/contribute_example.cfg @@ -83,9 +83,8 @@ watchOngoingGameInFileName = watchgame.txt # cudaUseNHWC = auto -# ------------------------------ -# Metal GPU settings -# ------------------------------ +# Metal GPU settings-------------------------------------- + # These only apply when using the METAL version of KataGo. # For one Metal instance: KataGo will automatically use the default device. @@ -99,6 +98,55 @@ watchOngoingGameInFileName = watchgame.txt # The pattern continues for additional Metal instances. +# ROCm GPU settings-------------------------------------- +# These only apply when using the ROCm version of KataGo. + +# IF USING ONE GPU: optionally uncomment and change this if the GPU you want to use turns out to be not device 0 +# rocmDeviceToUse = 0 + +# IF USING TWO GPUS: Uncomment these two lines (AND set numNNServerThreadsPerModel above): +# rocmDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# rocmDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 + +# IF USING THREE GPUS: Uncomment these three lines (AND set numNNServerThreadsPerModel above): +# rocmDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# rocmDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 +# rocmDeviceToUseThread2 = 2 # change this if the third GPU you want to use turns out to be not device 2 + +# You can probably guess the pattern if you have four, five, etc. GPUs. + +# KataGo will automatically use FP16 or not based on the compute capability of your AMD GPU. If you +# want to try to force a particular behavior though you can uncomment these lines and change them +# to "true" or "false". E.g. it's using FP16 but on your card that's giving an error, or it's not using +# FP16 but you think it should. +# rocmUseFP16 = auto +# ROCm does not support NHWC, so this is always false. + + +# MIGraphX GPU settings-------------------------------------- +# These only apply when using the MIGraphX version of KataGo. + +# IF USING ONE GPU: optionally uncomment and change this if the GPU you want to use turns out to be not device 0 +# mgxDeviceToUse = 0 + +# IF USING TWO GPUS: Uncomment these two lines (AND set numNNServerThreadsPerModel above): +# mgxDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# mgxDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 + +# IF USING THREE GPUS: Uncomment these three lines (AND set numNNServerThreadsPerModel above): +# mgxDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# mgxDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 +# mgxDeviceToUseThread2 = 2 # change this if the third GPU you want to use turns out to be not device 2 + +# You can probably guess the pattern if you have four, five, etc. GPUs. + +# KataGo will automatically use FP16 or not based on the compute capability of your AMD GPU. If you +# want to try to force a particular behavior though you can uncomment these lines and change them +# to "true" or "false". E.g. it's using FP16 but on your card that's giving an error, or it's not using +# FP16 but you think it should. +# mgxUseFP16 = auto + + # OpenCL GPU settings-------------------------------------- # These only apply when using the OpenCL version of KataGo. diff --git a/cpp/configs/gtp_example.cfg b/cpp/configs/gtp_example.cfg index cfa720bf3..c37901fa7 100644 --- a/cpp/configs/gtp_example.cfg +++ b/cpp/configs/gtp_example.cfg @@ -455,9 +455,9 @@ searchFactorWhenWinningThreshold = 0.95 # cudaUseFP16 = auto # cudaUseNHWC = auto -# ------------------------------ -# Metal GPU settings -# ------------------------------ + +# Metal GPU settings-------------------------------------- + # These only apply when using the METAL version of KataGo. # For one Metal instance: KataGo will automatically use the default device. @@ -470,6 +470,56 @@ searchFactorWhenWinningThreshold = 0.95 # The pattern continues for additional Metal instances. + +# ROCm GPU settings-------------------------------------- +# These only apply when using the ROCm version of KataGo. + +# IF USING ONE GPU: optionally uncomment and change this if the GPU you want to use turns out to be not device 0 +# rocmDeviceToUse = 0 + +# IF USING TWO GPUS: Uncomment these two lines (AND set numNNServerThreadsPerModel above): +# rocmDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# rocmDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 + +# IF USING THREE GPUS: Uncomment these three lines (AND set numNNServerThreadsPerModel above): +# rocmDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# rocmDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 +# rocmDeviceToUseThread2 = 2 # change this if the third GPU you want to use turns out to be not device 2 + +# You can probably guess the pattern if you have four, five, etc. GPUs. + +# KataGo will automatically use FP16 or not based on the compute capability of your AMD GPU. If you +# want to try to force a particular behavior though you can uncomment these lines and change them +# to "true" or "false". E.g. it's using FP16 but on your card that's giving an error, or it's not using +# FP16 but you think it should. +# rocmUseFP16 = auto +# ROCm does not support NHWC, so this is always false. + + +# MIGraphX GPU settings-------------------------------------- +# These only apply when using the MIGraphX version of KataGo. + +# IF USING ONE GPU: optionally uncomment and change this if the GPU you want to use turns out to be not device 0 +# mgxDeviceToUse = 0 + +# IF USING TWO GPUS: Uncomment these two lines (AND set numNNServerThreadsPerModel above): +# mgxDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# mgxDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 + +# IF USING THREE GPUS: Uncomment these three lines (AND set numNNServerThreadsPerModel above): +# mgxDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# mgxDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 +# mgxDeviceToUseThread2 = 2 # change this if the third GPU you want to use turns out to be not device 2 + +# You can probably guess the pattern if you have four, five, etc. GPUs. + +# KataGo will automatically use FP16 or not based on the compute capability of your AMD GPU. If you +# want to try to force a particular behavior though you can uncomment these lines and change them +# to "true" or "false". E.g. it's using FP16 but on your card that's giving an error, or it's not using +# FP16 but you think it should. +# mgxUseFP16 = auto + + # ------------------------------ # OpenCL GPU settings # ------------------------------ diff --git a/cpp/configs/match_example.cfg b/cpp/configs/match_example.cfg index 7e5b4fc09..b9e2895bb 100644 --- a/cpp/configs/match_example.cfg +++ b/cpp/configs/match_example.cfg @@ -156,9 +156,8 @@ numNNServerThreadsPerModel = 1 # cudaUseNHWC = auto -# ------------------------------ -# Metal GPU settings -# ------------------------------ +# Metal GPU settings-------------------------------------- + # These only apply when using the METAL version of KataGo. # For one Metal instance: KataGo will automatically use the default device. @@ -172,6 +171,55 @@ numNNServerThreadsPerModel = 1 # The pattern continues for additional Metal instances. +# ROCm GPU settings-------------------------------------- +# These only apply when using the ROCm version of KataGo. + +# IF USING ONE GPU: optionally uncomment and change this if the GPU you want to use turns out to be not device 0 +# rocmDeviceToUse = 0 + +# IF USING TWO GPUS: Uncomment these two lines (AND set numNNServerThreadsPerModel above): +# rocmDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# rocmDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 + +# IF USING THREE GPUS: Uncomment these three lines (AND set numNNServerThreadsPerModel above): +# rocmDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# rocmDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 +# rocmDeviceToUseThread2 = 2 # change this if the third GPU you want to use turns out to be not device 2 + +# You can probably guess the pattern if you have four, five, etc. GPUs. + +# KataGo will automatically use FP16 or not based on the compute capability of your AMD GPU. If you +# want to try to force a particular behavior though you can uncomment these lines and change them +# to "true" or "false". E.g. it's using FP16 but on your card that's giving an error, or it's not using +# FP16 but you think it should. +# rocmUseFP16 = auto +# ROCm does not support NHWC, so this is always false. + + +# MIGraphX GPU settings-------------------------------------- +# These only apply when using the MIGraphX version of KataGo. + +# IF USING ONE GPU: optionally uncomment and change this if the GPU you want to use turns out to be not device 0 +# mgxDeviceToUse = 0 + +# IF USING TWO GPUS: Uncomment these two lines (AND set numNNServerThreadsPerModel above): +# mgxDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# mgxDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 + +# IF USING THREE GPUS: Uncomment these three lines (AND set numNNServerThreadsPerModel above): +# mgxDeviceToUseThread0 = 0 # change this if the first GPU you want to use turns out to be not device 0 +# mgxDeviceToUseThread1 = 1 # change this if the second GPU you want to use turns out to be not device 1 +# mgxDeviceToUseThread2 = 2 # change this if the third GPU you want to use turns out to be not device 2 + +# You can probably guess the pattern if you have four, five, etc. GPUs. + +# KataGo will automatically use FP16 or not based on the compute capability of your AMD GPU. If you +# want to try to force a particular behavior though you can uncomment these lines and change them +# to "true" or "false". E.g. it's using FP16 but on your card that's giving an error, or it's not using +# FP16 but you think it should. +# mgxUseFP16 = auto + + # OpenCL GPU settings-------------------------------------- # These only apply when using OpenCL as the backend for inference. # (For GTP, we only ever have one model, when playing matches, we might have more than one, see match_example.cfg) diff --git a/cpp/main.cpp b/cpp/main.cpp index 0fcc36dea..688f301a7 100644 --- a/cpp/main.cpp +++ b/cpp/main.cpp @@ -246,6 +246,15 @@ string Version::getKataGoVersionFullInfo() { out << "Using Metal backend" << endl; #elif defined(USE_OPENCL_BACKEND) out << "Using OpenCL backend" << endl; +#elif defined(USE_ROCM_BACKEND) + out << "Using ROCm backend" << endl; +#if defined(HIP_TARGET_VERSION) +#define STRINGIFY(x) #x +#define STRINGIFY2(x) STRINGIFY(x) + out << "Compiled with HIP runtime version " << STRINGIFY2(HIP_TARGET_VERSION) << endl; +#endif +#elif defined(USE_MIGRAPHX_BACKEND) + out << "Using MIGraphX backend" << endl; #elif defined(USE_EIGEN_BACKEND) out << "Using Eigen(CPU) backend" << endl; #else @@ -278,6 +287,8 @@ string Version::getGitRevisionWithBackend() { s += "-cuda"; #elif defined(USE_TENSORRT_BACKEND) s += "-trt"; +#elif defined(USE_ROCM_BACKEND) + s += "-rocm"; #elif defined(USE_METAL_BACKEND) s += "-metal"; #elif defined(USE_OPENCL_BACKEND) diff --git a/cpp/neuralnet/migraphxbackend.cpp b/cpp/neuralnet/migraphxbackend.cpp new file mode 100644 index 000000000..3eee11926 --- /dev/null +++ b/cpp/neuralnet/migraphxbackend.cpp @@ -0,0 +1,1886 @@ +#include "../neuralnet/nninterface.h" +#include "../neuralnet/nninputs.h" +#include "../neuralnet/nneval.h" +#include "../neuralnet/modelversion.h" +#include "../neuralnet/desc.h" +#include "../neuralnet/sgfmetadata.h" +#include "../neuralnet/activations.h" +#include "../neuralnet/activations.h" + +#include "../core/fileutils.h" +#include "../core/makedir.h" +#include "../core/sha2.h" +#include "../dataio/homedata.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; + +//------------------------ MIGraphX Backend Documentation ------------------------ +// +// This is a MIGraphX backend implementation for KataGo. +// +// Current Status: +// - Full model weight loading from ModelDesc +// - Complete residual network structure (28 blocks for b28c512nbt) +// - Input/output tensor handling +// - Working inference with MIGraphX GPU backend +// +// Known Limitations: +// - BatchNorm is simplified (skipped) due to MIGraphX broadcast limitations +// - Global pooling residual blocks use simplified implementation +// - Value/Score/Ownership heads use simplified projections +// +// Future Optimizations: +// - Implement proper BatchNorm with broadcast +// - Full global pooling residual block implementation +// - Complete value head with v2Mul/v3Mul layers +// - FP16 support for faster inference +// +//------------------------ MIGraphX Model Implementation ------------------------ + +struct MIGraphXModel { + // Multiple compiled programs for different batch sizes + // Key: batch size, Value: compiled program + map progs; + migraphx::target tgt; + // Sorted batch sizes for quick lookup + vector batchSizes; + + int modelVersion; + int maxBatchSize; + int nnXLen, nnYLen; + bool useFP16; + bool useNHWC; + + int numInputChannels; + int numInputGlobalChannels; + int numInputMetaChannels; + int numPolicyChannels; + int numValueChannels; + int numScoreValueChannels; + int numOwnershipChannels; + + MIGraphXModel() + : modelVersion(0), maxBatchSize(1), nnXLen(19), nnYLen(19), + useFP16(false), useNHWC(false), + numInputChannels(0), numInputGlobalChannels(0), numInputMetaChannels(0), + numPolicyChannels(0), numValueChannels(3), + numScoreValueChannels(0), numOwnershipChannels(0) {} + + // Find the best (smallest sufficient) batch size for the given actual batch + int getBestBatchSize(int actualBatch) const { + for(int bs : batchSizes) { + if(bs >= actualBatch) return bs; + } + return batchSizes.back(); + } + + migraphx::program& getProgram(int batchSize) { + return progs.at(batchSize); + } +}; + +// Helper class to build MIGraphX graph +class MIGraphXGraphBuilder { +public: + migraphx::module* main_module; + migraphx::shape::type_t dataType; + int batchSize; + int nnXLen, nnYLen; + + MIGraphXGraphBuilder(migraphx::module* mod, migraphx::shape::type_t dtype, int batch, int x, int y) + : main_module(mod), dataType(dtype), batchSize(batch), nnXLen(x), nnYLen(y) {} + + // Add a convolution layer + migraphx::instruction_ref addConv( + migraphx::instruction_ref input, + const ConvLayerDesc& convDesc + ) { + // Validate dimensions + if(convDesc.inChannels <= 0 || convDesc.inChannels > 10000 || + convDesc.outChannels <= 0 || convDesc.outChannels > 10000 || + convDesc.convYSize <= 0 || convDesc.convYSize > 100 || + convDesc.convXSize <= 0 || convDesc.convXSize > 100) { + cerr << "ERROR: Conv " << convDesc.name << " has invalid dimensions (in=" << convDesc.inChannels + << ", out=" << convDesc.outChannels << ", ky=" << convDesc.convYSize + << ", kx=" << convDesc.convXSize << ")" << endl; + return input; + } + + vector wShape = { + (size_t)convDesc.outChannels, + (size_t)convDesc.inChannels, + (size_t)convDesc.convYSize, + (size_t)convDesc.convXSize + }; + size_t expectedWeights = (size_t)convDesc.outChannels * (size_t)convDesc.inChannels + * (size_t)convDesc.convYSize * (size_t)convDesc.convXSize; + + if(convDesc.weights.size() != expectedWeights) { + cerr << "ERROR: Conv " << convDesc.name << " weights size mismatch: " + << convDesc.weights.size() << " vs expected " << expectedWeights + << " (out=" << convDesc.outChannels << ", in=" << convDesc.inChannels + << ", ky=" << convDesc.convYSize << ", kx=" << convDesc.convXSize << ")" << endl; + return input; // Return input to avoid crash + } + + auto weights = addLiteral(convDesc.weights, wShape); + + int padY = (convDesc.convYSize - 1) / 2 * convDesc.dilationY; + int padX = (convDesc.convXSize - 1) / 2 * convDesc.dilationX; + + // Use vector for array values + vector padding = {(size_t)padY, (size_t)padX}; + vector stride = {1, 1}; + vector dilation = {(size_t)convDesc.dilationY, (size_t)convDesc.dilationX}; + + auto conv_op = migraphx::make_op("convolution", { + {"padding", migraphx::value(padding)}, + {"stride", migraphx::value(stride)}, + {"dilation", migraphx::value(dilation)}, + {"group", 1} + }); + + return main_module->add_instruction(conv_op, input, weights); + } + + // Add batch normalization (inference mode) - full implementation using multibroadcast + migraphx::instruction_ref addBatchNorm( + migraphx::instruction_ref input, + const BatchNormLayerDesc& bnDesc + ) { + // Skip if BN has no channels or invalid weights + if(bnDesc.numChannels <= 0 || bnDesc.numChannels > 10000) { + cerr << "WARNING: BatchNorm " << bnDesc.name << " has invalid numChannels=" << bnDesc.numChannels + << ", skipping BN" << endl; + return input; + } + + int numChannels = bnDesc.numChannels; + + // Validate weight sizes match numChannels + if(bnDesc.mergedScale.size() != (size_t)numChannels || bnDesc.mergedBias.size() != (size_t)numChannels) { + cerr << "WARNING: BatchNorm " << bnDesc.name << " weight size mismatch (C=" << numChannels + << ", scale=" << bnDesc.mergedScale.size() << ", bias=" << bnDesc.mergedBias.size() + << "), skipping BN" << endl; + return input; + } + + // Create scale and bias literals from mergedScale and mergedBias + vector paramShape = {(size_t)numChannels}; + auto scale = addLiteral(bnDesc.mergedScale, paramShape); + auto bias = addLiteral(bnDesc.mergedBias, paramShape); + + // Get input shape for broadcasting + auto input_shape = input->get_shape(); + vector input_lens = input_shape.lens(); + + // Unsqueeze scale and bias from [C] to [1, C, 1, 1] for broadcasting + auto scale_unsqueezed = main_module->add_instruction( + migraphx::make_op("unsqueeze", {{"axes", migraphx::value(vector{0, 2, 3})}}), scale); + auto bias_unsqueezed = main_module->add_instruction( + migraphx::make_op("unsqueeze", {{"axes", migraphx::value(vector{0, 2, 3})}}), bias); + + // Broadcast scale and bias to input shape using multibroadcast + // Input is NCHW: [batch, channels, height, width] + auto scale_broadcast = main_module->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), scale_unsqueezed); + auto bias_broadcast = main_module->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), bias_unsqueezed); + + // Apply scale and bias: y = x * scale + bias + auto scaled = main_module->add_instruction(migraphx::make_op("mul"), input, scale_broadcast); + auto result = main_module->add_instruction(migraphx::make_op("add"), scaled, bias_broadcast); + + return result; + } + + // Add MatMul layer + migraphx::instruction_ref addMatMul( + migraphx::instruction_ref input, + const MatMulLayerDesc& matmulDesc, + const MatBiasLayerDesc* biasDesc = nullptr + ) { + // Validate channel counts + if(matmulDesc.inChannels <= 0 || matmulDesc.inChannels > 10000 || + matmulDesc.outChannels <= 0 || matmulDesc.outChannels > 10000) { + cerr << "ERROR: MatMul " << matmulDesc.name << " has invalid channels (in=" + << matmulDesc.inChannels << ", out=" << matmulDesc.outChannels << ")" << endl; + return input; + } + + vector wShape = {(size_t)matmulDesc.inChannels, (size_t)matmulDesc.outChannels}; + size_t expectedWeights = (size_t)matmulDesc.inChannels * (size_t)matmulDesc.outChannels; + if(matmulDesc.weights.size() != expectedWeights) { + cerr << "ERROR: MatMul " << matmulDesc.name << " weights size mismatch: " + << matmulDesc.weights.size() << " vs expected " << expectedWeights + << " (in=" << matmulDesc.inChannels << ", out=" << matmulDesc.outChannels << ")" << endl; + // Return input to avoid crash (this will break the model but prevent segfault) + return input; + } + auto weights = addLiteral(matmulDesc.weights, wShape); + + auto matmul = main_module->add_instruction(migraphx::make_op("dot"), input, weights); + + if(biasDesc != nullptr && !biasDesc->weights.empty()) { + if(biasDesc->weights.size() != (size_t)biasDesc->numChannels) { + cerr << "ERROR: MatMul bias " << biasDesc->name << " size mismatch: " + << biasDesc->weights.size() << " vs expected " << biasDesc->numChannels << endl; + } else { + vector bShape = {(size_t)biasDesc->numChannels}; + auto bias = addLiteral(biasDesc->weights, bShape); + + // Unsqueeze for broadcasting: [numChannels] -> [1, numChannels] + auto unsqueeze_op = migraphx::make_op("unsqueeze", {{"axes", migraphx::value({0})}}); + bias = main_module->add_instruction(unsqueeze_op, bias); + + // Explicit broadcast to match matmul output shape + auto matmulShape = matmul->get_shape().lens(); + bias = main_module->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", matmulShape}}), bias); + + matmul = main_module->add_instruction(migraphx::make_op("add"), matmul, bias); + } + } + + return matmul; + } + + // Add activation + migraphx::instruction_ref addActivation(migraphx::instruction_ref input, int activationType) { + if(activationType == ACTIVATION_IDENTITY) { + return input; + } + else if(activationType == ACTIVATION_RELU) { + return main_module->add_instruction(migraphx::make_op("relu"), input); + } + else if(activationType == ACTIVATION_MISH) { + return addMish(input); + } + else if(activationType == ACTIVATION_MISH_SCALE8) { + return addMishScale8(input); + } + // Fallback to relu + return main_module->add_instruction(migraphx::make_op("relu"), input); + } + + // Mish activation: x * tanh(softplus(x)) = x * tanh(log(1 + exp(x))) + migraphx::instruction_ref addMish(migraphx::instruction_ref input) { + auto inputLens = input->get_shape().lens(); + // softplus(x) = log(1 + exp(x)) + auto exp_x = main_module->add_instruction(migraphx::make_op("exp"), input); + auto ones = broadcastScalar(1.0f, inputLens); + auto one_plus_exp = main_module->add_instruction(migraphx::make_op("add"), exp_x, ones); + auto softplus = main_module->add_instruction(migraphx::make_op("log"), one_plus_exp); + auto tanh_sp = main_module->add_instruction(migraphx::make_op("tanh"), softplus); + return main_module->add_instruction(migraphx::make_op("mul"), input, tanh_sp); + } + + // Mish-scale8 activation: x * tanh(softplus(clamp(8x, -, 30))) + // For x >= 2.5: tanh(softplus(20+)) ≈ 1, so result ≈ x (identity) + // For x < 2.5: standard mish with 8x scaling of softplus argument + migraphx::instruction_ref addMishScale8(migraphx::instruction_ref input) { + auto inputLens = input->get_shape().lens(); + // scaled = 8 * x, clamped to max 30 to prevent exp overflow + auto eight = broadcastScalar(8.0f, inputLens); + auto scaled = main_module->add_instruction(migraphx::make_op("mul"), input, eight); + auto thirty = broadcastScalar(30.0f, inputLens); + scaled = main_module->add_instruction(migraphx::make_op("min"), scaled, thirty); + // softplus(scaled) = log(1 + exp(scaled)) + auto exp_s = main_module->add_instruction(migraphx::make_op("exp"), scaled); + auto ones = broadcastScalar(1.0f, inputLens); + auto one_plus_exp = main_module->add_instruction(migraphx::make_op("add"), exp_s, ones); + auto softplus = main_module->add_instruction(migraphx::make_op("log"), one_plus_exp); + auto tanh_sp = main_module->add_instruction(migraphx::make_op("tanh"), softplus); + return main_module->add_instruction(migraphx::make_op("mul"), input, tanh_sp); + } + + // Helper: broadcast a scalar to the given shape + migraphx::instruction_ref broadcastScalar(float val, const vector& targetLens) { + vector onesShape(targetLens.size(), 1); + auto lit = addLiteral({val}, onesShape); + return main_module->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", targetLens}}), lit); + } + + // Add literal + migraphx::instruction_ref addLiteral(const vector& data, const vector& dims) { + migraphx::shape s(dataType, dims); + return main_module->add_literal(migraphx::literal(s, data)); + } + + // Convert tensor to specified data type + migraphx::instruction_ref addConvert(migraphx::instruction_ref input, migraphx::shape::type_t targetType) { + if(input->get_shape().type() == targetType) { + return input; + } + auto convert_op = migraphx::make_op("convert", {{"target_type", targetType}}); + return main_module->add_instruction(convert_op, input); + } + + // Global average pooling + migraphx::instruction_ref addGlobalAvgPool(migraphx::instruction_ref input) { + auto pool_op = migraphx::make_op("pooling", { + {"mode", 0}, // average + {"padding", migraphx::value({0, 0})}, + {"stride", migraphx::value({(size_t)nnYLen, (size_t)nnXLen})}, + {"lengths", migraphx::value({(size_t)nnYLen, (size_t)nnXLen})} + }); + return main_module->add_instruction(pool_op, input); + } + + // Flatten + migraphx::instruction_ref addFlatten(migraphx::instruction_ref input, size_t axis = 1) { + auto flatten_op = migraphx::make_op("flatten", {{"axis", axis}}); + return main_module->add_instruction(flatten_op, input); + } + + // Squeeze + migraphx::instruction_ref addSqueeze(migraphx::instruction_ref input, const vector& axes) { + auto squeeze_op = migraphx::make_op("squeeze", {{"axes", migraphx::value(axes)}}); + return main_module->add_instruction(squeeze_op, input); + } + + // Tanh + migraphx::instruction_ref addTanh(migraphx::instruction_ref input) { + return main_module->add_instruction(migraphx::make_op("tanh"), input); + } + + // Reduce sum over specified axes + migraphx::instruction_ref addReduceSum(migraphx::instruction_ref input, const vector& axes) { + auto reduce_op = migraphx::make_op("reduce_sum", {{"axes", migraphx::value(axes)}}); + return main_module->add_instruction(reduce_op, input); + } + + // Reduce max over specified axes + migraphx::instruction_ref addReduceMax(migraphx::instruction_ref input, const vector& axes) { + auto reduce_op = migraphx::make_op("reduce_max", {{"axes", migraphx::value(axes)}}); + return main_module->add_instruction(reduce_op, input); + } + + // Reduce mean over specified axes + migraphx::instruction_ref addReduceMean(migraphx::instruction_ref input, const vector& axes) { + auto reduce_op = migraphx::make_op("reduce_mean", {{"axes", migraphx::value(axes)}}); + return main_module->add_instruction(reduce_op, input); + } + + // Element-wise multiplication + migraphx::instruction_ref addMul(migraphx::instruction_ref a, migraphx::instruction_ref b) { + return main_module->add_instruction(migraphx::make_op("mul"), a, b); + } + + // Element-wise addition + migraphx::instruction_ref addAdd(migraphx::instruction_ref a, migraphx::instruction_ref b) { + return main_module->add_instruction(migraphx::make_op("add"), a, b); + } + + // Element-wise subtraction + migraphx::instruction_ref addSub(migraphx::instruction_ref a, migraphx::instruction_ref b) { + return main_module->add_instruction(migraphx::make_op("sub"), a, b); + } + + // Element-wise division + migraphx::instruction_ref addDiv(migraphx::instruction_ref a, migraphx::instruction_ref b) { + return main_module->add_instruction(migraphx::make_op("div"), a, b); + } + + // Power operation + migraphx::instruction_ref addPow(migraphx::instruction_ref input, float exponent) { + vector expData = {exponent}; + auto expLit = addLiteral(expData, {1, 1, 1, 1}); + return main_module->add_instruction(migraphx::make_op("pow"), input, expLit); + } + + // Sqrt operation + migraphx::instruction_ref addSqrt(migraphx::instruction_ref input) { + return main_module->add_instruction(migraphx::make_op("sqrt"), input); + } + + // Transpose operation + migraphx::instruction_ref addTranspose(migraphx::instruction_ref input, const vector& dims) { + auto transpose_op = migraphx::make_op("transpose", {{"dims", migraphx::value(dims)}}); + return main_module->add_instruction(transpose_op, input); + } + + // Concatenate along axis + migraphx::instruction_ref addConcat(const vector& inputs, int64_t axis) { + auto concat_op = migraphx::make_op("concat", {{"axis", axis}}); + return main_module->add_instruction(concat_op, inputs); + } + + // Global pooling producing 3 features per channel. + // For trunk/policy: [mean, mean*scale1, max] + // For value head: [mean, mean*scale1, mean*scale2] + // Input: [batch, C, H, W], Output: [batch, C*3] + // Note: assumes full board (no mask), correct for standard play at nnXLen x nnYLen. + migraphx::instruction_ref addGPool(migraphx::instruction_ref input, bool isValueHead = false) { + float boardArea = (float)(nnXLen * nnYLen); + float sqrtBoardArea = sqrtf(boardArea); + float scale1Factor = (sqrtBoardArea - 14.0f) * 0.1f; + + // mean: [batch, C, H, W] -> [batch, C, 1, 1] -> [batch, C] + auto mean = addReduceMean(input, {2, 3}); + mean = addSqueeze(mean, {2, 3}); + + auto meanShape = mean->get_shape().lens(); + + // scale1 = mean * scale1Factor + auto scale1Lit = addLiteral({scale1Factor}, {1, 1}); + auto scale1Broadcast = main_module->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", meanShape}}), scale1Lit); + auto scale1 = main_module->add_instruction(migraphx::make_op("mul"), mean, scale1Broadcast); + + migraphx::instruction_ref third; + if(isValueHead) { + // scale2 = mean * ((sqrtBoardArea - 14)^2 * 0.01 - 0.1) + float scale2Factor = (sqrtBoardArea - 14.0f) * (sqrtBoardArea - 14.0f) * 0.01f - 0.1f; + auto scale2Lit = addLiteral({scale2Factor}, {1, 1}); + auto scale2Broadcast = main_module->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", meanShape}}), scale2Lit); + third = main_module->add_instruction(migraphx::make_op("mul"), mean, scale2Broadcast); + } else { + // max: [batch, C, H, W] -> [batch, C, 1, 1] -> [batch, C] + auto maxVal = addReduceMax(input, {2, 3}); + third = addSqueeze(maxVal, {2, 3}); + } + + // Concat [mean, scale1, third] along axis 1 -> [batch, C*3] + return addConcat({mean, scale1, third}, 1); + } + +}; + +// Build residual block +static migraphx::instruction_ref buildResidualBlock( + MIGraphXGraphBuilder& builder, + migraphx::instruction_ref input, + const ResidualBlockDesc& blockDesc +) { + auto residual = input; + + // preBN + preActivation + auto x = builder.addBatchNorm(input, blockDesc.preBN); + x = builder.addActivation(x, blockDesc.preActivation.activation); + + // regularConv + x = builder.addConv(x, blockDesc.regularConv); + x = builder.addBatchNorm(x, blockDesc.midBN); + + // midActivation + x = builder.addActivation(x, blockDesc.midActivation.activation); + + // finalConv + x = builder.addConv(x, blockDesc.finalConv); + + // Add residual + return builder.main_module->add_instruction(migraphx::make_op("add"), x, residual); +} + +// Forward declarations +static migraphx::instruction_ref buildResidualBlock( + MIGraphXGraphBuilder& builder, + migraphx::instruction_ref input, + const ResidualBlockDesc& blockDesc +); + +static migraphx::instruction_ref buildGlobalPoolingResidualBlock( + MIGraphXGraphBuilder& builder, + migraphx::instruction_ref input, + const GlobalPoolingResidualBlockDesc& blockDesc +); + +static migraphx::instruction_ref buildNestedBottleneckResidualBlock( + MIGraphXGraphBuilder& builder, + migraphx::instruction_ref input, + const NestedBottleneckResidualBlockDesc& blockDesc +); + +static migraphx::instruction_ref buildResidualBlockStack( + MIGraphXGraphBuilder& builder, + migraphx::instruction_ref input, + const std::vector>& blocks, + const string& namePrefix +); + +// Build nested bottleneck residual block +static migraphx::instruction_ref buildNestedBottleneckResidualBlock( + MIGraphXGraphBuilder& builder, + migraphx::instruction_ref input, + const NestedBottleneckResidualBlockDesc& blockDesc +) { + auto residual = input; + + // Pre BN + Activation + auto x = builder.addBatchNorm(input, blockDesc.preBN); + x = builder.addActivation(x, blockDesc.preActivation.activation); + + // Pre conv (bottleneck down) + x = builder.addConv(x, blockDesc.preConv); + + // Inner residual block stack + x = buildResidualBlockStack(builder, x, blockDesc.blocks, blockDesc.name); + + // Post BN + Activation + x = builder.addBatchNorm(x, blockDesc.postBN); + x = builder.addActivation(x, blockDesc.postActivation.activation); + + // Post conv (bottleneck up) + x = builder.addConv(x, blockDesc.postConv); + + // Add residual + return builder.main_module->add_instruction(migraphx::make_op("add"), x, residual); +} + +// Build residual block stack (used by trunk and nested blocks) +static migraphx::instruction_ref buildResidualBlockStack( + MIGraphXGraphBuilder& builder, + migraphx::instruction_ref input, + const std::vector>& blocks, + const string& namePrefix +) { + auto trunk = input; + + for(size_t i = 0; i < blocks.size(); i++) { + int blockKind = blocks[i].first; + + if(blockKind == ORDINARY_BLOCK_KIND) { + const ResidualBlockDesc* blockDesc = static_cast(blocks[i].second.get()); + trunk = buildResidualBlock(builder, trunk, *blockDesc); + } else if(blockKind == GLOBAL_POOLING_BLOCK_KIND) { + const GlobalPoolingResidualBlockDesc* blockDesc = static_cast(blocks[i].second.get()); + trunk = buildGlobalPoolingResidualBlock(builder, trunk, *blockDesc); + } else if(blockKind == NESTED_BOTTLENECK_BLOCK_KIND) { + const NestedBottleneckResidualBlockDesc* blockDesc = static_cast(blocks[i].second.get()); + trunk = buildNestedBottleneckResidualBlock(builder, trunk, *blockDesc); + } + + } + + return trunk; +} + +// Build global pooling residual block - full implementation +static migraphx::instruction_ref buildGlobalPoolingResidualBlock( + MIGraphXGraphBuilder& builder, + migraphx::instruction_ref input, + const GlobalPoolingResidualBlockDesc& blockDesc +) { + auto residual = input; + + // preBN + preActivation + auto x = builder.addBatchNorm(input, blockDesc.preBN); + x = builder.addActivation(x, blockDesc.preActivation.activation); + + // Branch A: regular spatial conv + auto regularOut = builder.addConv(x, blockDesc.regularConv); + + // Branch B: global pooling conv + auto gpoolOut = builder.addConv(x, blockDesc.gpoolConv); + gpoolOut = builder.addBatchNorm(gpoolOut, blockDesc.gpoolBN); + gpoolOut = builder.addActivation(gpoolOut, blockDesc.gpoolActivation.activation); + + // Global pool: [batch, gpoolC, H, W] -> [batch, gpoolC*3] + auto gpoolFeatures = builder.addGPool(gpoolOut, false); + + // gpoolToBiasMul: [batch, gpoolC*3] -> [batch, regularC] + auto bias = builder.addMatMul(gpoolFeatures, blockDesc.gpoolToBiasMul); + + // Broadcast bias to spatial dims and add to regularOut + auto regularShape = regularOut->get_shape().lens(); + auto biasUnsqueezed = builder.main_module->add_instruction( + migraphx::make_op("unsqueeze", {{"axes", migraphx::value(vector{2, 3})}}), bias); + auto biasBroadcast = builder.main_module->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", regularShape}}), biasUnsqueezed); + regularOut = builder.main_module->add_instruction(migraphx::make_op("add"), regularOut, biasBroadcast); + + // midBN + midActivation + regularOut = builder.addBatchNorm(regularOut, blockDesc.midBN); + regularOut = builder.addActivation(regularOut, blockDesc.midActivation.activation); + + // finalConv + regularOut = builder.addConv(regularOut, blockDesc.finalConv); + + // Add residual + return builder.main_module->add_instruction(migraphx::make_op("add"), regularOut, residual); +} + +// Build complete MIGraphX program from ModelDesc +static migraphx::program buildMIGraphXProgram( + const ModelDesc& modelDesc, + int maxBatchSize, + int nnXLen, + int nnYLen, + bool useFP16, + bool useNHWC +) { + migraphx::program prog; + auto main_module = prog.get_main_module(); + + migraphx::shape::type_t dataType = useFP16 ? migraphx::shape::half_type : migraphx::shape::float_type; + + int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelDesc.modelVersion); + int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelDesc.modelVersion); + int numMetaFeatures = modelDesc.numInputMetaChannels; + + // Create input parameters + vector inputShape = {(size_t)maxBatchSize, (size_t)numSpatialFeatures, (size_t)nnYLen, (size_t)nnXLen}; + vector inputGlobalShape = {(size_t)maxBatchSize, (size_t)numGlobalFeatures}; + + // Input parameters are always float_type (host buffers are float). + // If using FP16, we convert to half inside the graph so MIGraphX handles conversion on GPU. + auto inputSpatial = main_module->add_parameter("input_spatial", migraphx::shape(migraphx::shape::float_type, inputShape)); + auto inputGlobal = main_module->add_parameter("input_global", migraphx::shape(migraphx::shape::float_type, inputGlobalShape)); + + // MIGraphX backend uses NCHW format only + (void)useNHWC; // Silently ignore NHWC setting + + MIGraphXGraphBuilder builder(main_module, dataType, maxBatchSize, nnXLen, nnYLen); + + // Convert inputs to computation type if using FP16 + if(useFP16) { + inputSpatial = builder.addConvert(inputSpatial, dataType); + inputGlobal = builder.addConvert(inputGlobal, dataType); + } + + // Build trunk + auto trunk = inputSpatial; + const TrunkDesc& trunkDesc = modelDesc.trunk; + + // Initial conv + if(trunkDesc.initialConv.outChannels > 0 && trunkDesc.initialConv.inChannels == numSpatialFeatures) { + trunk = builder.addConv(trunk, trunkDesc.initialConv); + } else if(trunkDesc.initialConv.outChannels > 0) { + cout << "MIGraphX: Skipping initialConv (input channel mismatch)" << endl; + } + + // Initial MatMul for global features + if(trunkDesc.initialMatMul.outChannels > 0) { + auto globalProcessed = builder.addMatMul(inputGlobal, trunkDesc.initialMatMul); + // Broadcast global features from [N, C] to spatial dimensions [N, C, H, W] + auto trunkShape = trunk->get_shape().lens(); + auto globalUnsqueezed = main_module->add_instruction( + migraphx::make_op("unsqueeze", {{"axes", migraphx::value(vector{2, 3})}}), globalProcessed); + auto globalBroadcast = main_module->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", trunkShape}}), globalUnsqueezed); + trunk = main_module->add_instruction(migraphx::make_op("add"), trunk, globalBroadcast); + } + + // SGF Metadata encoder (if enabled) - disabled for now due to potential weight shape issues + if(trunkDesc.metaEncoderVersion > 0 && numMetaFeatures > 0) { + // Skip SGF metadata encoder for now + cout << "MIGraphX: SGF Metadata encoder disabled" << endl; + } + + // Residual blocks using the stack builder + trunk = buildResidualBlockStack(builder, trunk, trunkDesc.blocks, "trunk"); + + // trunkTipBN + trunkTipActivation + trunk = builder.addBatchNorm(trunk, trunkDesc.trunkTipBN); + trunk = builder.addActivation(trunk, trunkDesc.trunkTipActivation.activation); + + // ======== Policy Head ======== + const PolicyHeadDesc& policyDesc = modelDesc.policyHead; + + migraphx::instruction_ref policy = trunk; + migraphx::instruction_ref policyPass = trunk; // will be overwritten + + if(policyDesc.p1Conv.outChannels > 0) { + // p1Conv branch (spatial policy) + auto p1Conv = builder.addConv(trunk, policyDesc.p1Conv); + + // g1Conv branch for global pooling + auto g1Conv = builder.addConv(trunk, policyDesc.g1Conv); + g1Conv = builder.addBatchNorm(g1Conv, policyDesc.g1BN); + g1Conv = builder.addActivation(g1Conv, policyDesc.g1Activation.activation); + + // Global pool: [batch, g1C, H, W] -> [batch, g1C*3] + auto gpool = builder.addGPool(g1Conv, false); + + // gpoolToBiasMul: [batch, g1C*3] -> [batch, p1C] bias + auto gpoolBias = builder.addMatMul(gpool, policyDesc.gpoolToBiasMul); + + // Broadcast bias and add to p1Conv + auto p1Shape = p1Conv->get_shape().lens(); + auto biasUnsqueezed = main_module->add_instruction( + migraphx::make_op("unsqueeze", {{"axes", migraphx::value(vector{2, 3})}}), gpoolBias); + auto biasBroadcast = main_module->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", p1Shape}}), biasUnsqueezed); + policy = main_module->add_instruction(migraphx::make_op("add"), p1Conv, biasBroadcast); + + policy = builder.addBatchNorm(policy, policyDesc.p1BN); + policy = builder.addActivation(policy, policyDesc.p1Activation.activation); + + // p2Conv -> spatial policy logits + if(policyDesc.p2Conv.outChannels > 0) { + policy = builder.addConv(policy, policyDesc.p2Conv); + } + + // Flatten spatial policy: [batch, numPolicyChannels, H, W] -> [batch, numPolicyChannels*H*W] + policy = builder.addFlatten(policy); + + // Pass policy (separate path from spatial, uses same gpool) + // gpoolToPassMul: [batch, g1C*3] -> passHidden + policyPass = builder.addMatMul(gpool, policyDesc.gpoolToPassMul, &policyDesc.gpoolToPassBias); + policyPass = builder.addActivation(policyPass, policyDesc.passActivation.activation); + + // gpoolToPassMul2: passHidden -> [batch, numPolicyChannels] (for modelVersion >= 15) + if(policyDesc.gpoolToPassMul2.outChannels > 0) { + policyPass = builder.addMatMul(policyPass, policyDesc.gpoolToPassMul2); + } + } else { + policy = builder.addFlatten(trunk); + // Zero pass policy fallback + vector zeroPass(modelDesc.numPolicyChannels, 0.0f); + policyPass = builder.addLiteral(zeroPass, {1, (size_t)modelDesc.numPolicyChannels}); + policyPass = main_module->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", vector{(size_t)maxBatchSize, (size_t)modelDesc.numPolicyChannels}}}), policyPass); + } + + // ======== Value Head ======== + const ValueHeadDesc& valueDesc = modelDesc.valueHead; + + // v1Conv + v1BN + v1Activation + auto v1Out = builder.addConv(trunk, valueDesc.v1Conv); + v1Out = builder.addBatchNorm(v1Out, valueDesc.v1BN); + v1Out = builder.addActivation(v1Out, valueDesc.v1Activation.activation); + + // Ownership branch: v1Out -> vOwnershipConv -> flatten (no tanh - matches CUDA backend) + auto ownership = builder.addConv(v1Out, valueDesc.vOwnershipConv); + ownership = builder.addFlatten(ownership); + + // Value branch: v1Out -> GPool (value head style) -> v2Mul + v2Bias + v2Activation -> v3Mul + v3Bias + auto vGpool = builder.addGPool(v1Out, true); // value head: mean, scale1, scale2 + + auto v2 = builder.addMatMul(vGpool, valueDesc.v2Mul, &valueDesc.v2Bias); + v2 = builder.addActivation(v2, valueDesc.v2Activation.activation); + + auto valueOut = builder.addMatMul(v2, valueDesc.v3Mul, &valueDesc.v3Bias); + + // Score value branch: same v2 -> sv3Mul + sv3Bias + auto scoreValue = builder.addMatMul(v2, valueDesc.sv3Mul, &valueDesc.sv3Bias); + + // Set outputs: policy, policyPass, value, scoreValue, ownership + if(modelDesc.modelVersion >= 2) { + main_module->add_return({policy, policyPass, valueOut, scoreValue, ownership}); + } else { + main_module->add_return({policy, valueOut}); + } + + return prog; +} + +//------------------------ Backend Structures ------------------------ + +struct LoadedModelInternal { + ModelDesc modelDesc; + string modelFile; + string expectedSha256; + + LoadedModelInternal(const string& file, const string& sha256) : modelFile(file), expectedSha256(sha256) { + ModelDesc::loadFromFileMaybeGZipped(file, modelDesc, sha256); + modelDesc.applyScale8ToReduceActivations(); + } +}; + +struct ComputeContextInternal { + int nnXLen, nnYLen; + enabled_t useFP16Mode; + enabled_t useNHWCMode; + string homeDataDir; + vector gpuIdxs; +}; + +struct ComputeHandleInternal { + unique_ptr model; + int maxBatchSize; + int gpuIdx; + bool requireExactNNLen; + bool inputsUseNHWC; + int nnXLen, nnYLen; +}; + +struct InputBuffersInternal { + int maxBatchSize; + int nnXLen, nnYLen; + + size_t singleInputElts; + size_t singleInputBytes; + size_t singleInputGlobalElts; + size_t singleInputGlobalBytes; + size_t singleInputMetaElts; + size_t singleInputMetaBytes; + + size_t userInputBufferBytes; + size_t userInputGlobalBufferBytes; + size_t userInputMetaBufferBytes; + + vector userInputBuffer; + vector userInputGlobalBuffer; + vector userInputMetaBuffer; + + size_t singlePolicyResultElts; + size_t singlePolicyResultBytes; + size_t singlePolicyPassResultElts; + size_t singlePolicyPassResultBytes; + size_t singleValueResultElts; + size_t singleValueResultBytes; + size_t singleScoreValueResultElts; + size_t singleScoreValueResultBytes; + size_t singleOwnershipResultElts; + size_t singleOwnershipResultBytes; + + vector policyResults; + vector policyPassResults; + vector valueResults; + vector scoreValueResults; + vector ownershipResults; + + size_t policyResultBufferBytes; + size_t policyPassResultBufferBytes; + size_t valueResultBufferBytes; + size_t scoreValueResultBufferBytes; + size_t ownershipResultBufferBytes; +}; + +//------------------------ NeuralNet Implementation ------------------------ + +namespace NeuralNet { + +void globalInitialize() {} +void globalCleanup() {} + +void printDevices() { + cout << "MIGraphX Backend: AMD GPU via MIGraphX" << endl; +} + +LoadedModel* loadModelFile(const string& file, const string& expectedSha256) { + return reinterpret_cast(new LoadedModelInternal(file, expectedSha256)); +} + +void freeLoadedModel(LoadedModel* loadedModel) { + if(loadedModel) { + LoadedModelInternal* model = reinterpret_cast(loadedModel); + delete model; + } +} + +const ModelDesc& getModelDesc(const LoadedModel* loadedModel) { + return reinterpret_cast(loadedModel)->modelDesc; +} + +ComputeContext* createComputeContext( + const vector& gpuIdxs, + Logger* logger, + int nnXLen, + int nnYLen, + const string& openCLTunerFile, + const string& homeDataDirOverride, + bool openCLReTunePerBoardSize, + enabled_t useFP16Mode, + enabled_t useNHWCMode, + const LoadedModel* loadedModel +) { + (void)logger; + (void)openCLTunerFile; + (void)homeDataDirOverride; + (void)openCLReTunePerBoardSize; + (void)loadedModel; + + auto context = new ComputeContextInternal(); + context->gpuIdxs = gpuIdxs; + context->nnXLen = nnXLen; + context->nnYLen = nnYLen; + context->useFP16Mode = useFP16Mode; + context->useNHWCMode = useNHWCMode; + + return reinterpret_cast(context); +} + +void freeComputeContext(ComputeContext* computeContext) { + if(computeContext) { + ComputeContextInternal* context = reinterpret_cast(computeContext); + delete context; + } +} + +// Static mutex for cache operations +static mutex migraphxCacheMutex; + +// Generate batch sizes to compile for MIGraphX (no dynamic batch support). +static vector generateBatchSizes(int maxBatchSize) { + vector candidates = {4, 8, 16, 24, 32, 40, 64}; + + // Keep only sizes <= maxBatchSize, always include maxBatchSize itself + vector sizes; + for(int s : candidates) { + if(s <= maxBatchSize) + sizes.push_back(s); + } + if(sizes.empty() || sizes.back() != maxBatchSize) + sizes.push_back(maxBatchSize); + return sizes; +} + +// Generate cache file path +static string getCacheFilePath( + const string& homeDataDir, + const ModelDesc& modelDesc, + int nnXLen, + int nnYLen, + int maxBatchSize, + bool useFP16, + bool useNHWC, + bool requireExactNNLen +) { + auto cacheDir = HomeData::getHomeDataDir(true, homeDataDir); + cacheDir += "/migraphxcache"; + + // Create directory if not exists + MakeDir::make(cacheDir); + + // Generate unique cache key based on model and parameters + string cacheKey = Global::strprintf( + "migraphx_%s_%s_%dx%d_batch%d_fp%d_nhwc%d_%s", + modelDesc.name.c_str(), + modelDesc.sha256.substr(0, 16).c_str(), + nnYLen, + nnXLen, + maxBatchSize, + useFP16 ? 1 : 0, + useNHWC ? 1 : 0, + requireExactNNLen ? "exact" : "max" + ); + + return cacheDir + "/" + cacheKey + ".mxr"; +} + +ComputeHandle* createComputeHandle( + ComputeContext* context, + const LoadedModel* loadedModel, + Logger* logger, + int maxBatchSize, + bool requireExactNNLen, + bool inputsUseNHWC, + int gpuIdxForThisThread, + int serverThreadIdx +) { + (void)serverThreadIdx; + + ComputeContextInternal* ctx = reinterpret_cast(context); + const LoadedModelInternal* model = reinterpret_cast(loadedModel); + + auto handle = new ComputeHandleInternal(); + handle->maxBatchSize = maxBatchSize; + handle->gpuIdx = gpuIdxForThisThread; + handle->requireExactNNLen = requireExactNNLen; + handle->inputsUseNHWC = inputsUseNHWC; + handle->nnXLen = ctx->nnXLen; + handle->nnYLen = ctx->nnYLen; + + bool useFP16 = (ctx->useFP16Mode == enabled_t::True || ctx->useFP16Mode == enabled_t::Auto); + bool useNHWC = (ctx->useNHWCMode == enabled_t::True); + + // MIGraphX backend only supports NCHW format + if(useNHWC) { + cout << "MIGraphX: WARNING: NHWC format is not supported, forcing NCHW" << endl; + useNHWC = false; + } + + handle->model = make_unique(); + handle->model->modelVersion = model->modelDesc.modelVersion; + handle->model->maxBatchSize = maxBatchSize; + handle->model->nnXLen = ctx->nnXLen; + handle->model->nnYLen = ctx->nnYLen; + handle->model->useFP16 = useFP16; + handle->model->useNHWC = false; // Always NCHW + + handle->model->numInputChannels = model->modelDesc.numInputChannels; + handle->model->numInputGlobalChannels = model->modelDesc.numInputGlobalChannels; + handle->model->numInputMetaChannels = model->modelDesc.numInputMetaChannels; + handle->model->numPolicyChannels = model->modelDesc.numPolicyChannels; + handle->model->numValueChannels = model->modelDesc.numValueChannels; + handle->model->numScoreValueChannels = model->modelDesc.numScoreValueChannels; + handle->model->numOwnershipChannels = model->modelDesc.numOwnershipChannels; + + // Generate batch sizes to compile + vector batchSizesToCompile = generateBatchSizes(maxBatchSize); + handle->model->batchSizes = batchSizesToCompile; + handle->model->tgt = migraphx::make_target("gpu"); + + lock_guard cacheLock(migraphxCacheMutex); + + for(int bs : batchSizesToCompile) { + // Generate cache file path for this batch size + string cacheFile = getCacheFilePath( + ctx->homeDataDir, + model->modelDesc, + ctx->nnXLen, + ctx->nnYLen, + bs, + useFP16, + useNHWC, + requireExactNNLen + ); + + bool cacheLoaded = false; + + // Try to load from cache + if(FileUtils::exists(cacheFile)) { + try { + if(logger) { + logger->write("MIGraphX: Loading compiled program from cache (batch " + Global::intToString(bs) + "): " + cacheFile); + } + cout << "MIGraphX: Loading batch " << bs << " from cache..." << endl; + + handle->model->progs[bs] = migraphx::load(cacheFile); + cacheLoaded = true; + + cout << "MIGraphX: Batch " << bs << " loaded! (FP16: " << (useFP16 ? "yes" : "no") << ")" << endl; + } catch(const exception& e) { + if(logger) { + logger->write(string("MIGraphX: Cache load failed for batch ") + Global::intToString(bs) + ": " + e.what()); + } + cout << "MIGraphX: Cache load failed for batch " << bs << ", rebuilding..." << endl; + } + } + + if(!cacheLoaded) { + cout << "MIGraphX: Building model (version " << model->modelDesc.modelVersion << ")..." << endl; + cout << " Board size: " << ctx->nnXLen << "x" << ctx->nnYLen << endl; + cout << " Batch size: " << bs << endl; + cout << " FP16: " << (useFP16 ? "yes" : "no") << endl; + cout << " NHWC: " << (useNHWC ? "yes" : "no") << endl; + cout << " Trunk channels: " << model->modelDesc.trunk.trunkNumChannels << endl; + cout << " Num blocks: " << model->modelDesc.trunk.numBlocks << endl; + + handle->model->progs[bs] = buildMIGraphXProgram( + model->modelDesc, + bs, + ctx->nnXLen, + ctx->nnYLen, + useFP16, + useNHWC + ); + + cout << "MIGraphX: Compiling batch " << bs << "..." << endl; + migraphx::compile_options compile_opts; + compile_opts.offload_copy = true; + + handle->model->progs[bs].compile(handle->model->tgt, compile_opts); + + cout << "MIGraphX: Batch " << bs << " compiled!" << endl; + + // Save to cache + try { + if(logger) { + logger->write("MIGraphX: Saving compiled program to cache (batch " + Global::intToString(bs) + "): " + cacheFile); + } + migraphx::save(handle->model->progs[bs], cacheFile); + cout << "MIGraphX: Batch " << bs << " cached!" << endl; + } catch(const exception& e) { + if(logger) { + logger->write(string("MIGraphX: Cache save failed: ") + e.what()); + } + cout << "MIGraphX: Cache save failed: " << e.what() << endl; + } + } + } + + cout << "MIGraphX: All " << batchSizesToCompile.size() << " batch sizes ready: "; + for(size_t i = 0; i < batchSizesToCompile.size(); i++) { + if(i > 0) cout << ", "; + cout << batchSizesToCompile[i]; + } + cout << endl; + + return reinterpret_cast(handle); +} + +void freeComputeHandle(ComputeHandle* computeHandle) { + if(computeHandle) { + ComputeHandleInternal* handle = reinterpret_cast(computeHandle); + delete handle; + } +} + +bool isUsingFP16(const ComputeHandle* computeHandle) { + const ComputeHandleInternal* handle = reinterpret_cast(computeHandle); + return handle->model->useFP16; +} + +InputBuffers* createInputBuffers(const LoadedModel* loadedModel, int maxBatchSize, int nnXLen, int nnYLen) { + const ModelDesc& m = getModelDesc(loadedModel); + + auto buffers = new InputBuffersInternal(); + buffers->maxBatchSize = maxBatchSize; + buffers->nnXLen = nnXLen; + buffers->nnYLen = nnYLen; + + int modelVersion = m.modelVersion; + int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion); + int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion); + int numMetaFeatures = m.numInputMetaChannels; + + buffers->singleInputElts = (size_t)numSpatialFeatures * nnXLen * nnYLen; + buffers->singleInputBytes = buffers->singleInputElts * sizeof(float); + buffers->singleInputGlobalElts = numGlobalFeatures; + buffers->singleInputGlobalBytes = buffers->singleInputGlobalElts * sizeof(float); + buffers->singleInputMetaElts = numMetaFeatures; + buffers->singleInputMetaBytes = buffers->singleInputMetaElts * sizeof(float); + + buffers->userInputBufferBytes = buffers->singleInputBytes * maxBatchSize; + buffers->userInputGlobalBufferBytes = buffers->singleInputGlobalBytes * maxBatchSize; + buffers->userInputMetaBufferBytes = buffers->singleInputMetaBytes * maxBatchSize; + + buffers->userInputBuffer.resize(buffers->singleInputElts * maxBatchSize, 0.0f); + buffers->userInputGlobalBuffer.resize(buffers->singleInputGlobalElts * maxBatchSize, 0.0f); + buffers->userInputMetaBuffer.resize(buffers->singleInputMetaElts * maxBatchSize, 0.0f); + + buffers->singlePolicyResultElts = m.numPolicyChannels * nnXLen * nnYLen; + buffers->singlePolicyResultBytes = buffers->singlePolicyResultElts * sizeof(float); + buffers->singlePolicyPassResultElts = m.numPolicyChannels; + buffers->singlePolicyPassResultBytes = buffers->singlePolicyPassResultElts * sizeof(float); + + buffers->singleValueResultElts = m.numValueChannels; + buffers->singleValueResultBytes = buffers->singleValueResultElts * sizeof(float); + buffers->singleScoreValueResultElts = max(1, m.numScoreValueChannels); + buffers->singleScoreValueResultBytes = buffers->singleScoreValueResultElts * sizeof(float); + buffers->singleOwnershipResultElts = nnXLen * nnYLen; + buffers->singleOwnershipResultBytes = buffers->singleOwnershipResultElts * sizeof(float); + + buffers->policyResultBufferBytes = buffers->singlePolicyResultBytes * maxBatchSize; + buffers->policyPassResultBufferBytes = buffers->singlePolicyPassResultBytes * maxBatchSize; + buffers->valueResultBufferBytes = buffers->singleValueResultBytes * maxBatchSize; + buffers->scoreValueResultBufferBytes = buffers->singleScoreValueResultBytes * maxBatchSize; + buffers->ownershipResultBufferBytes = buffers->singleOwnershipResultBytes * maxBatchSize; + + buffers->policyResults.resize(buffers->singlePolicyResultElts * maxBatchSize, 0.0f); + buffers->policyPassResults.resize(buffers->singlePolicyPassResultElts * maxBatchSize, 0.0f); + buffers->valueResults.resize(buffers->singleValueResultElts * maxBatchSize, 0.0f); + buffers->scoreValueResults.resize(buffers->singleScoreValueResultElts * maxBatchSize, 0.0f); + buffers->ownershipResults.resize(buffers->singleOwnershipResultElts * maxBatchSize, 0.0f); + + return reinterpret_cast(buffers); +} + +void freeInputBuffers(InputBuffers* buffers) { + if(buffers) { + InputBuffersInternal* data = reinterpret_cast(buffers); + delete data; + } +} + +void getOutput( + ComputeHandle* computeHandle, + InputBuffers* inputBuffers, + int numBatchEltsFilled, + NNResultBuf** inputBufs, + vector& outputs +) { + ComputeHandleInternal* handle = reinterpret_cast(computeHandle); + InputBuffersInternal* buffers = reinterpret_cast(inputBuffers); + + assert(numBatchEltsFilled <= buffers->maxBatchSize); + assert(numBatchEltsFilled > 0); + + int batchSize = numBatchEltsFilled; + int nnXLen = handle->nnXLen; + int nnYLen = handle->nnYLen; + int modelVersion = handle->model->modelVersion; + + int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion); + int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion); + int numMetaFeatures = handle->model->numInputMetaChannels; + + // Copy inputs + for(int nIdx = 0; nIdx < batchSize; nIdx++) { + float* rowSpatialInput = buffers->userInputBuffer.data() + (buffers->singleInputElts * nIdx); + float* rowGlobalInput = buffers->userInputGlobalBuffer.data() + (buffers->singleInputGlobalElts * nIdx); + float* rowMetaInput = buffers->userInputMetaBuffer.data() + (buffers->singleInputMetaElts * nIdx); + + const float* rowGlobal = inputBufs[nIdx]->rowGlobalBuf.data(); + const float* rowSpatial = inputBufs[nIdx]->rowSpatialBuf.data(); + const float* rowMeta = inputBufs[nIdx]->rowMetaBuf.data(); + bool hasRowMeta = inputBufs[nIdx]->hasRowMeta; + + std::copy(rowGlobal, rowGlobal + numGlobalFeatures, rowGlobalInput); + if(numMetaFeatures > 0) { + assert(rowMeta != NULL); + assert(hasRowMeta); + std::copy(rowMeta, rowMeta + numMetaFeatures, rowMetaInput); + } + + SymmetryHelpers::copyInputsWithSymmetry( + rowSpatial, rowSpatialInput, 1, nnYLen, nnXLen, numSpatialFeatures, + handle->inputsUseNHWC, inputBufs[nIdx]->symmetry + ); + } + + // Run inference - pick the smallest compiled batch size that fits + int bestBatchSize = handle->model->getBestBatchSize(batchSize); + migraphx::parameter_map params; + + // Always use float_type for input shapes - host buffers are float, graph handles conversion + migraphx::shape input_shape( + migraphx::shape::float_type, + {(size_t)bestBatchSize, (size_t)numSpatialFeatures, (size_t)nnYLen, (size_t)nnXLen} + ); + params["input_spatial"] = migraphx::argument(input_shape, buffers->userInputBuffer.data()); + + migraphx::shape global_shape( + migraphx::shape::float_type, + {(size_t)bestBatchSize, (size_t)numGlobalFeatures} + ); + params["input_global"] = migraphx::argument(global_shape, buffers->userInputGlobalBuffer.data()); + + auto results = handle->model->getProgram(bestBatchSize).eval(params); + + // Extract results from MIGraphX eval into buffers + // Output order for modelVersion >= 2: policy, policyPass, value, scoreValue, ownership + int numPolicyChannels = handle->model->numPolicyChannels; + size_t policySize = (size_t)numPolicyChannels * nnXLen * nnYLen; + int numValueChannels = handle->model->numValueChannels; + int numScoreValueChannels = handle->model->numScoreValueChannels; + size_t ownershipSize = (size_t)nnXLen * nnYLen; + + // Policy: [maxBatchSize, numPolicyChannels * H * W] + if(results.size() > 0) { + results[0].visit([&](auto output) { + for(int row = 0; row < batchSize; row++) { + for(size_t i = 0; i < policySize; i++) { + buffers->policyResults[row * policySize + i] = static_cast(output[row * policySize + i]); + } + } + }); + } + + if(modelVersion >= 2) { + // Policy pass: [maxBatchSize, numPolicyChannels] + if(results.size() > 1) { + results[1].visit([&](auto output) { + for(int row = 0; row < batchSize; row++) { + for(int i = 0; i < numPolicyChannels; i++) { + buffers->policyPassResults[row * numPolicyChannels + i] = static_cast(output[row * numPolicyChannels + i]); + } + } + }); + } + + // Value: [maxBatchSize, numValueChannels] + if(results.size() > 2) { + results[2].visit([&](auto output) { + for(int row = 0; row < batchSize; row++) { + for(int i = 0; i < numValueChannels; i++) { + buffers->valueResults[row * numValueChannels + i] = static_cast(output[row * numValueChannels + i]); + } + } + }); + } + + // Score value: [maxBatchSize, numScoreValueChannels] + if(results.size() > 3) { + results[3].visit([&](auto output) { + for(int row = 0; row < batchSize; row++) { + for(int i = 0; i < numScoreValueChannels; i++) { + buffers->scoreValueResults[row * numScoreValueChannels + i] = static_cast(output[row * numScoreValueChannels + i]); + } + } + }); + } + + // Ownership: [maxBatchSize, H * W] + if(results.size() > 4) { + results[4].visit([&](auto output) { + for(int row = 0; row < batchSize; row++) { + for(size_t i = 0; i < ownershipSize; i++) { + buffers->ownershipResults[row * ownershipSize + i] = static_cast(output[row * ownershipSize + i]); + } + } + }); + } + } else { + // Value: [maxBatchSize, numValueChannels] + if(results.size() > 1) { + results[1].visit([&](auto output) { + for(int row = 0; row < batchSize; row++) { + for(int i = 0; i < numValueChannels; i++) { + buffers->valueResults[row * numValueChannels + i] = static_cast(output[row * numValueChannels + i]); + } + } + }); + } + } + + // Process outputs per row + assert(outputs.size() == (size_t)batchSize); + + float policyProbsTmp[NNPos::MAX_NN_POLICY_SIZE]; + + for(int row = 0; row < batchSize; row++) { + NNOutput* output = outputs[row]; + assert(output->nnXLen == nnXLen); + assert(output->nnYLen == nnYLen); + float policyOptimism = (float)inputBufs[row]->policyOptimism; + + const float* policyPassSrcBuf = buffers->policyPassResults.data() + row * numPolicyChannels; + const float* policySrcBuf = buffers->policyResults.data() + row * policySize; + float* policyProbs = output->policyProbs; + + if(numPolicyChannels == 2 || (numPolicyChannels == 4 && modelVersion >= 16)) { + for(int i = 0; i < nnXLen * nnYLen; i++) { + float p = policySrcBuf[i]; + float pOpt = policySrcBuf[i + nnXLen * nnYLen]; + policyProbsTmp[i] = p + (pOpt - p) * policyOptimism; + } + SymmetryHelpers::copyOutputsWithSymmetry( + policyProbsTmp, policyProbs, 1, nnYLen, nnXLen, inputBufs[row]->symmetry + ); + policyProbs[nnXLen * nnYLen] = policyPassSrcBuf[0] + (policyPassSrcBuf[1] - policyPassSrcBuf[0]) * policyOptimism; + } else { + assert(numPolicyChannels == 1); + SymmetryHelpers::copyOutputsWithSymmetry( + policySrcBuf, policyProbs, 1, nnYLen, nnXLen, inputBufs[row]->symmetry + ); + policyProbs[nnXLen * nnYLen] = policyPassSrcBuf[0]; + } + + assert(numValueChannels == 3); + output->whiteWinProb = buffers->valueResults[row * numValueChannels]; + output->whiteLossProb = buffers->valueResults[row * numValueChannels + 1]; + output->whiteNoResultProb = buffers->valueResults[row * numValueChannels + 2]; + + if(output->whiteOwnerMap != NULL) { + const float* ownershipSrcBuf = buffers->ownershipResults.data() + row * ownershipSize; + assert(handle->model->numOwnershipChannels == 1); + SymmetryHelpers::copyOutputsWithSymmetry(ownershipSrcBuf, output->whiteOwnerMap, 1, nnYLen, nnXLen, inputBufs[row]->symmetry); + } + + if(modelVersion >= 9) { + assert(numScoreValueChannels == 6); + output->whiteScoreMean = buffers->scoreValueResults[row * numScoreValueChannels]; + output->whiteScoreMeanSq = buffers->scoreValueResults[row * numScoreValueChannels + 1]; + output->whiteLead = buffers->scoreValueResults[row * numScoreValueChannels + 2]; + output->varTimeLeft = buffers->scoreValueResults[row * numScoreValueChannels + 3]; + output->shorttermWinlossError = buffers->scoreValueResults[row * numScoreValueChannels + 4]; + output->shorttermScoreError = buffers->scoreValueResults[row * numScoreValueChannels + 5]; + } else if(modelVersion >= 8) { + assert(numScoreValueChannels == 4); + output->whiteScoreMean = buffers->scoreValueResults[row * numScoreValueChannels]; + output->whiteScoreMeanSq = buffers->scoreValueResults[row * numScoreValueChannels + 1]; + output->whiteLead = buffers->scoreValueResults[row * numScoreValueChannels + 2]; + output->varTimeLeft = buffers->scoreValueResults[row * numScoreValueChannels + 3]; + output->shorttermWinlossError = 0.0f; + output->shorttermScoreError = 0.0f; + } else if(modelVersion >= 4) { + assert(numScoreValueChannels == 2); + output->whiteScoreMean = buffers->scoreValueResults[row * numScoreValueChannels]; + output->whiteScoreMeanSq = buffers->scoreValueResults[row * numScoreValueChannels + 1]; + output->whiteLead = output->whiteScoreMean; + output->varTimeLeft = 0.0f; + output->shorttermWinlossError = 0.0f; + output->shorttermScoreError = 0.0f; + } else if(modelVersion >= 3) { + assert(numScoreValueChannels == 1); + output->whiteScoreMean = buffers->scoreValueResults[row * numScoreValueChannels]; + output->whiteScoreMeanSq = output->whiteScoreMean * output->whiteScoreMean; + output->whiteLead = output->whiteScoreMean; + output->varTimeLeft = 0.0f; + output->shorttermWinlossError = 0.0f; + output->shorttermScoreError = 0.0f; + } else { + output->whiteScoreMean = 0.0f; + output->whiteScoreMeanSq = 1.0f; + output->whiteLead = 0.0f; + output->varTimeLeft = 0.0f; + output->shorttermWinlossError = 0.0f; + output->shorttermScoreError = 0.0f; + } + + output->policyOptimismUsed = policyOptimism; + } +} + +// Test functions - implemented using MIGraphX for layer verification +bool testEvaluateConv( + const ConvLayerDesc* desc, + int batchSize, + int nnXLen, + int nnYLen, + bool useFP16, + bool useNHWC, + const vector& inputBuffer, + vector& outputBuffer +) { + // Skip NHWC tests - MIGraphX backend uses NCHW format + if(useNHWC) + return false; + + try { + migraphx::program prog; + auto main_module = prog.get_main_module(); + + migraphx::shape::type_t dataType = useFP16 ? migraphx::shape::half_type : migraphx::shape::float_type; + vector inputShape = {(size_t)batchSize, (size_t)desc->inChannels, (size_t)nnYLen, (size_t)nnXLen}; + + auto input = main_module->add_parameter("input", migraphx::shape(dataType, inputShape)); + + // Create weights - MIGraphX expects float data, will convert internally + vector wShape = {(size_t)desc->outChannels, (size_t)desc->inChannels, (size_t)desc->convYSize, (size_t)desc->convXSize}; + migraphx::shape wShapeDesc(dataType, wShape); + auto weights = main_module->add_literal(migraphx::literal(wShapeDesc, desc->weights)); + + // Convolution + int padY = (desc->convYSize - 1) / 2 * desc->dilationY; + int padX = (desc->convXSize - 1) / 2 * desc->dilationX; + vector padding = {(size_t)padY, (size_t)padX}; + vector stride = {1, 1}; + vector dilation = {(size_t)desc->dilationY, (size_t)desc->dilationX}; + + auto conv_op = migraphx::make_op("convolution", { + {"padding", migraphx::value(padding)}, + {"stride", migraphx::value(stride)}, + {"dilation", migraphx::value(dilation)}, + {"group", 1} + }); + + auto conv = main_module->add_instruction(conv_op, input, weights); + main_module->add_return({conv}); + + // Compile and run + migraphx::compile_options compile_opts; + compile_opts.offload_copy = true; + auto target = migraphx::make_target("gpu"); + prog.compile(target, compile_opts); + + migraphx::parameter_map params; + + // For FP16, we need to convert input data to half precision + vector halfInput; + if(useFP16) { + halfInput.resize(inputBuffer.size()); + for(size_t i = 0; i < inputBuffer.size(); i++) { + halfInput[i] = migraphx::half(inputBuffer[i]); + } + params["input"] = migraphx::argument(migraphx::shape(dataType, inputShape), halfInput.data()); + } else { + params["input"] = migraphx::argument(migraphx::shape(dataType, inputShape), const_cast(inputBuffer.data())); + } + + auto results = prog.eval(params); + + // Copy output + vector outputShape = {(size_t)batchSize, (size_t)desc->outChannels, (size_t)nnYLen, (size_t)nnXLen}; + size_t outputSize = batchSize * desc->outChannels * nnYLen * nnXLen; + outputBuffer.resize(outputSize); + + auto outputArg = results[0]; + if(useFP16) { + // Convert half output back to float + outputArg.visit([&](auto output) { + for(size_t i = 0; i < outputSize; i++) { + outputBuffer[i] = static_cast(output[i]); + } + }); + } else { + vector tempOutput(outputSize); + outputArg.visit([&](auto output) { + for(size_t i = 0; i < outputSize; i++) { + tempOutput[i] = static_cast(output[i]); + } + }); + outputBuffer = tempOutput; + } + + return true; + } catch(const exception& e) { + cerr << "testEvaluateConv failed: " << e.what() << endl; + return false; + } +} + +bool testEvaluateBatchNorm( + const BatchNormLayerDesc* desc, + int batchSize, + int nnXLen, + int nnYLen, + bool useFP16, + bool useNHWC, + const vector& inputBuffer, + const vector& maskBuffer, + vector& outputBuffer +) { + (void)maskBuffer; // BatchNorm doesn't use mask directly + + // Skip NHWC tests - MIGraphX backend uses NCHW format + if(useNHWC) + return false; + + // Validate weights are available + if(desc->mergedScale.size() != (size_t)desc->numChannels || desc->mergedBias.size() != (size_t)desc->numChannels) { + cerr << "BatchNorm test: weight size mismatch, skipping" << endl; + return false; + } + + try { + migraphx::program prog; + auto main_module = prog.get_main_module(); + + migraphx::shape::type_t dataType = useFP16 ? migraphx::shape::half_type : migraphx::shape::float_type; + vector inputShape = {(size_t)batchSize, (size_t)desc->numChannels, (size_t)nnYLen, (size_t)nnXLen}; + + auto input = main_module->add_parameter("input", migraphx::shape(dataType, inputShape)); + + // Create merged scale and bias + vector paramShape = {(size_t)desc->numChannels}; + migraphx::shape paramDesc(dataType, paramShape); + + auto scale = main_module->add_literal(migraphx::literal(paramDesc, desc->mergedScale)); + auto bias = main_module->add_literal(migraphx::literal(paramDesc, desc->mergedBias)); + + // Broadcast scale and bias to input shape + vector broadcastShape = {1, (size_t)desc->numChannels, 1, 1}; + auto scale_broadcast = main_module->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", inputShape}}), scale); + auto bias_broadcast = main_module->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", inputShape}}), bias); + + // Apply scale and bias: y = x * scale + bias + auto scaled = main_module->add_instruction(migraphx::make_op("mul"), input, scale_broadcast); + auto result = main_module->add_instruction(migraphx::make_op("add"), scaled, bias_broadcast); + + main_module->add_return({result}); + + // Compile and run + migraphx::compile_options compile_opts; + compile_opts.offload_copy = true; + auto target = migraphx::make_target("gpu"); + prog.compile(target, compile_opts); + + migraphx::parameter_map params; + + // For FP16, we need to convert input data to half precision + vector halfInput; + if(useFP16) { + halfInput.resize(inputBuffer.size()); + for(size_t i = 0; i < inputBuffer.size(); i++) { + halfInput[i] = migraphx::half(inputBuffer[i]); + } + params["input"] = migraphx::argument(migraphx::shape(dataType, inputShape), halfInput.data()); + } else { + params["input"] = migraphx::argument(migraphx::shape(dataType, inputShape), const_cast(inputBuffer.data())); + } + + auto results = prog.eval(params); + + // Copy output + size_t outputSize = batchSize * desc->numChannels * nnYLen * nnXLen; + outputBuffer.resize(outputSize); + + auto outputArg = results[0]; + if(useFP16) { + outputArg.visit([&](auto output) { + for(size_t i = 0; i < outputSize; i++) { + outputBuffer[i] = static_cast(output[i]); + } + }); + } else { + vector tempOutput(outputSize); + outputArg.visit([&](auto output) { + for(size_t i = 0; i < outputSize; i++) { + tempOutput[i] = static_cast(output[i]); + } + }); + outputBuffer = tempOutput; + } + + return true; + } catch(const exception& e) { + cerr << "testEvaluateBatchNorm failed: " << e.what() << endl; + return false; + } +} + +bool testEvaluateResidualBlock( + const ResidualBlockDesc* desc, + int batchSize, + int nnXLen, + int nnYLen, + bool useFP16, + bool useNHWC, + const vector& inputBuffer, + const vector& maskBuffer, + vector& outputBuffer +) { + (void)maskBuffer; + + // Skip NHWC tests - MIGraphX backend uses NCHW format + if(useNHWC) + return false; + + // Validate weights are available + size_t w1Expected = (size_t)desc->regularConv.outChannels * desc->regularConv.inChannels + * desc->regularConv.convYSize * desc->regularConv.convXSize; + size_t w2Expected = (size_t)desc->finalConv.outChannels * desc->finalConv.inChannels + * desc->finalConv.convYSize * desc->finalConv.convXSize; + if(desc->regularConv.weights.size() != w1Expected || desc->finalConv.weights.size() != w2Expected) { + cerr << "ResidualBlock test: weight size mismatch, skipping" << endl; + return false; + } + + try { + migraphx::program prog; + auto main_module = prog.get_main_module(); + + migraphx::shape::type_t dataType = useFP16 ? migraphx::shape::half_type : migraphx::shape::float_type; + int numChannels = desc->regularConv.inChannels; + vector inputShape = {(size_t)batchSize, (size_t)numChannels, (size_t)nnYLen, (size_t)nnXLen}; + + auto input = main_module->add_parameter("input", migraphx::shape(dataType, inputShape)); + + // Build residual block + auto residual = input; + + // preBN + preActivation (simplified - just activation for now) + auto x = input; + if(desc->preActivation.activation == 1) { // GELU + // Simplified GELU + auto sigmoid = main_module->add_instruction(migraphx::make_op("sigmoid"), x); + x = main_module->add_instruction(migraphx::make_op("mul"), x, sigmoid); + } else { + x = main_module->add_instruction(migraphx::make_op("relu"), x); + } + + // regularConv + vector w1Shape = {(size_t)desc->regularConv.outChannels, (size_t)desc->regularConv.inChannels, + (size_t)desc->regularConv.convYSize, (size_t)desc->regularConv.convXSize}; + migraphx::shape w1Desc(dataType, w1Shape); + auto w1 = main_module->add_literal(migraphx::literal(w1Desc, desc->regularConv.weights)); + + int pad1 = (desc->regularConv.convYSize - 1) / 2; + vector padding1 = {(size_t)pad1, (size_t)pad1}; + auto conv1_op = migraphx::make_op("convolution", { + {"padding", migraphx::value(padding1)}, + {"stride", migraphx::value(vector{1, 1})}, + {"dilation", migraphx::value(vector{(size_t)desc->regularConv.dilationY, (size_t)desc->regularConv.dilationX})}, + {"group", 1} + }); + x = main_module->add_instruction(conv1_op, x, w1); + + // midActivation + if(desc->midActivation.activation == 1) { + auto sigmoid = main_module->add_instruction(migraphx::make_op("sigmoid"), x); + x = main_module->add_instruction(migraphx::make_op("mul"), x, sigmoid); + } else { + x = main_module->add_instruction(migraphx::make_op("relu"), x); + } + + // finalConv + vector w2Shape = {(size_t)desc->finalConv.outChannels, (size_t)desc->finalConv.inChannels, + (size_t)desc->finalConv.convYSize, (size_t)desc->finalConv.convXSize}; + migraphx::shape w2Desc(dataType, w2Shape); + auto w2 = main_module->add_literal(migraphx::literal(w2Desc, desc->finalConv.weights)); + + int pad2 = (desc->finalConv.convYSize - 1) / 2; + vector padding2 = {(size_t)pad2, (size_t)pad2}; + auto conv2_op = migraphx::make_op("convolution", { + {"padding", migraphx::value(padding2)}, + {"stride", migraphx::value(vector{1, 1})}, + {"dilation", migraphx::value(vector{(size_t)desc->finalConv.dilationY, (size_t)desc->finalConv.dilationX})}, + {"group", 1} + }); + x = main_module->add_instruction(conv2_op, x, w2); + + // Add residual + auto result = main_module->add_instruction(migraphx::make_op("add"), x, residual); + + main_module->add_return({result}); + + // Compile and run + migraphx::compile_options compile_opts; + compile_opts.offload_copy = true; + auto target = migraphx::make_target("gpu"); + prog.compile(target, compile_opts); + + migraphx::parameter_map params; + + // For FP16, we need to convert input data to half precision + vector halfInput; + if(useFP16) { + halfInput.resize(inputBuffer.size()); + for(size_t i = 0; i < inputBuffer.size(); i++) { + halfInput[i] = migraphx::half(inputBuffer[i]); + } + params["input"] = migraphx::argument(migraphx::shape(dataType, inputShape), halfInput.data()); + } else { + params["input"] = migraphx::argument(migraphx::shape(dataType, inputShape), const_cast(inputBuffer.data())); + } + + auto results = prog.eval(params); + + // Copy output + size_t outputSize = batchSize * numChannels * nnYLen * nnXLen; + outputBuffer.resize(outputSize); + + auto outputArg = results[0]; + if(useFP16) { + outputArg.visit([&](auto output) { + for(size_t i = 0; i < outputSize; i++) { + outputBuffer[i] = static_cast(output[i]); + } + }); + } else { + vector tempOutput(outputSize); + outputArg.visit([&](auto output) { + for(size_t i = 0; i < outputSize; i++) { + tempOutput[i] = static_cast(output[i]); + } + }); + outputBuffer = tempOutput; + } + + return true; + } catch(const exception& e) { + cerr << "testEvaluateResidualBlock failed: " << e.what() << endl; + return false; + } +} + +bool testEvaluateGlobalPoolingResidualBlock( + const GlobalPoolingResidualBlockDesc* desc, + int batchSize, + int nnXLen, + int nnYLen, + bool useFP16, + bool useNHWC, + const vector& inputBuffer, + const vector& maskBuffer, + vector& outputBuffer +) { + (void)desc; + (void)batchSize; + (void)nnXLen; + (void)nnYLen; + (void)useFP16; + (void)useNHWC; + (void)inputBuffer; + (void)maskBuffer; + (void)outputBuffer; + + // Global pooling residual block tests not supported yet + return false; + + try { + migraphx::program prog; + auto main_module = prog.get_main_module(); + + migraphx::shape::type_t dataType = migraphx::shape::float_type; + int numChannels = desc->regularConv.inChannels; + vector inputShape = {(size_t)batchSize, (size_t)numChannels, (size_t)nnYLen, (size_t)nnXLen}; + + auto input = main_module->add_parameter("input", migraphx::shape(dataType, inputShape)); + + // Simplified global pooling residual block (without full gpool branch for now) + auto residual = input; + + // Activation + auto x = main_module->add_instruction(migraphx::make_op("relu"), input); + + // regularConv + vector wShape = {(size_t)desc->regularConv.outChannels, (size_t)desc->regularConv.inChannels, + (size_t)desc->regularConv.convYSize, (size_t)desc->regularConv.convXSize}; + migraphx::shape wDesc(dataType, wShape); + auto w = main_module->add_literal(migraphx::literal(wDesc, desc->regularConv.weights)); + + int pad = (desc->regularConv.convYSize - 1) / 2; + vector padding = {(size_t)pad, (size_t)pad}; + auto conv_op = migraphx::make_op("convolution", { + {"padding", migraphx::value(padding)}, + {"stride", migraphx::value(vector{1, 1})}, + {"dilation", migraphx::value(vector{(size_t)desc->regularConv.dilationY, (size_t)desc->regularConv.dilationX})}, + {"group", 1} + }); + x = main_module->add_instruction(conv_op, x, w); + + // midActivation + x = main_module->add_instruction(migraphx::make_op("relu"), x); + + // finalConv + vector w2Shape = {(size_t)desc->finalConv.outChannels, (size_t)desc->finalConv.inChannels, + (size_t)desc->finalConv.convYSize, (size_t)desc->finalConv.convXSize}; + migraphx::shape w2Desc(dataType, w2Shape); + auto w2 = main_module->add_literal(migraphx::literal(w2Desc, desc->finalConv.weights)); + + int pad2 = (desc->finalConv.convYSize - 1) / 2; + vector padding2 = {(size_t)pad2, (size_t)pad2}; + auto conv2_op = migraphx::make_op("convolution", { + {"padding", migraphx::value(padding2)}, + {"stride", migraphx::value(vector{1, 1})}, + {"dilation", migraphx::value(vector{(size_t)desc->finalConv.dilationY, (size_t)desc->finalConv.dilationX})}, + {"group", 1} + }); + x = main_module->add_instruction(conv2_op, x, w2); + + // Add residual + auto result = main_module->add_instruction(migraphx::make_op("add"), x, residual); + + main_module->add_return({result}); + + // Compile and run + migraphx::compile_options compile_opts; + compile_opts.offload_copy = true; + auto target = migraphx::make_target("gpu"); + prog.compile(target, compile_opts); + + migraphx::parameter_map params; + params["input"] = migraphx::argument(migraphx::shape(dataType, inputShape), const_cast(inputBuffer.data())); + + auto results = prog.eval(params); + + // Copy output + size_t outputSize = batchSize * numChannels * nnYLen * nnXLen; + outputBuffer.resize(outputSize); + + auto outputArg = results[0]; + vector tempOutput(outputSize); + outputArg.visit([&](auto output) { + for(size_t i = 0; i < outputSize; i++) { + tempOutput[i] = static_cast(output[i]); + } + }); + outputBuffer = tempOutput; + + return true; + } catch(const exception& e) { + cerr << "testEvaluateGlobalPoolingResidualBlock failed: " << e.what() << endl; + return false; + } +} + +} // namespace NeuralNet diff --git a/cpp/neuralnet/rocmbackend.cpp b/cpp/neuralnet/rocmbackend.cpp new file mode 100644 index 000000000..4fb2dea46 --- /dev/null +++ b/cpp/neuralnet/rocmbackend.cpp @@ -0,0 +1,3053 @@ +#ifdef USE_ROCM_BACKEND + +#include "../neuralnet/rocmerrorcheck.h" +#include "../neuralnet/rocmincludes.h" + +#include "../neuralnet/rocmhelpers.h" +#include "../neuralnet/rocmutils.h" +#include "../neuralnet/modelversion.h" +#include "../neuralnet/nninterface.h" +#include "../neuralnet/nninputs.h" +#include "../neuralnet/sgfmetadata.h" +#include "../neuralnet/nneval.h" +#include "../neuralnet/desc.h" + +#include "../core/simpleallocator.h" +#include "../core/test.h" + +#include "../external/half-2.2.0/include/half.hpp" + +//------------------------ +#include "../core/using.h" +//------------------------ + +using half_t = half_float::half; + +//Define this to print out some of the intermediate values of the neural net +//#define DEBUG_INTERMEDIATE_VALUES + +void NeuralNet::globalInitialize() { + //Empty for cudnn backend +} + +void NeuralNet::globalCleanup() { + hipDeviceReset(); +} + +struct CudaHandles { + hipblasHandle_t cublas; + miopenHandle_t cudnn; + const int majorComputeCapability; + const int minorComputeCapability; + + CudaHandles(int major, int minor) + : majorComputeCapability(major), + minorComputeCapability(minor) + { + CUBLAS_ERR("CudaHandles",hipblasCreate(&cublas)); + CUDNN_ERR("CudaHandles",miopenCreate(&cudnn)); + } + + ~CudaHandles() { + hipblasDestroy(cublas); + miopenDestroy(cudnn); + } + + static CudaHandles* cudaHandlesTesting() { + const int gpuIdxForThisThread = 0; + hipDeviceProp_t prop; + hipGetDeviceProperties(&prop,gpuIdxForThisThread); + return new CudaHandles(prop.major, prop.minor); + } + + CudaHandles(const CudaHandles&) = delete; + CudaHandles& operator=(const CudaHandles&) = delete; +}; + +//--------------------------------------------------------------------------------- + +template +struct ByBatchSize { + const int maxBatchSize; + T* data; + miopenStatus_t (*destroyFunc)(T); + + ByBatchSize() + : maxBatchSize(0), data(nullptr), destroyFunc(nullptr) + {} + + ByBatchSize( + int maxBatchSize_ + ) : maxBatchSize(maxBatchSize_), data(nullptr), destroyFunc(nullptr) { + data = new T[maxBatchSize]; + } + + ByBatchSize(const ByBatchSize&) = delete; + ByBatchSize& operator=(const ByBatchSize&) = delete; + + ~ByBatchSize() { + if(destroyFunc != nullptr && data != nullptr) { + for(int batchSize = 1; batchSize <= maxBatchSize; batchSize++) { + (*destroyFunc)(data[batchSize-1]); + } + } + if(data != nullptr) { + delete[] data; + data = nullptr; + } + } + T& operator[](int batchSize) { + return data[batchSize-1]; + } + const T& operator[](int batchSize) const { + return data[batchSize-1]; + } +}; + +template +struct ByBatchSizeView { + int maxBatchSize; + T* data; + + ByBatchSizeView() + : maxBatchSize(0), data(nullptr) + {} + + ByBatchSizeView(const ByBatchSize& toView) + : maxBatchSize(toView.maxBatchSize), data(toView.data) + {} + ByBatchSizeView& operator=(const ByBatchSize& toView) { + maxBatchSize = toView.maxBatchSize; + data = toView.data; + } + + ~ByBatchSizeView() { + } + T& operator[](int batchSize) { + return data[batchSize-1]; + } + const T& operator[](int batchSize) const { + return data[batchSize-1]; + } +}; + +//--------------------------------------------------------------------------------- + + +//channels, useFP16, useNHWC +typedef std::tuple CudnnTensorDesc4DKey; + +struct CudnnManager { + const string name; + const int maxBatchSize; + const int nnXLen; + const int nnYLen; + std::map*> tensorDesc4DByBatchSizeByKey; + + CudnnManager(string name_, int maxBatchSize_, int nnXLen_, int nnYLen_) + :name(name_), + maxBatchSize(maxBatchSize_), + nnXLen(nnXLen_), + nnYLen(nnYLen_), + tensorDesc4DByBatchSizeByKey() + { + } + + ~CudnnManager() { + for(auto& iter: tensorDesc4DByBatchSizeByKey) { + delete iter.second; + } + } + + ByBatchSizeView getTensorDesc4DByBatchSize( + int channels, bool useFP16, bool useNHWC + ) { + auto iter = tensorDesc4DByBatchSizeByKey.find({channels, useFP16, useNHWC}); + if(iter != tensorDesc4DByBatchSizeByKey.end()) { + return ByBatchSizeView(*(iter->second)); + } + ByBatchSize* descs = new ByBatchSize(maxBatchSize); + for(int batchSize = 1; batchSize <= maxBatchSize; batchSize++) { + miopenTensorDescriptor_t& desc = (*descs)[batchSize]; + CUDNN_ERR(name.c_str(),miopenCreateTensorDescriptor(&desc)); + CUDNN_ERR(name.c_str(),miopenSet4dTensorDescriptor( + desc, + (useFP16 ? miopenHalf : miopenFloat), + batchSize, + channels, + nnYLen, + nnXLen + )); + } + descs->destroyFunc = miopenDestroyTensorDescriptor; + tensorDesc4DByBatchSizeByKey[{channels, useFP16, useNHWC}] = descs; + return ByBatchSizeView(*descs); + } +}; + +//--------------------------------------------------------------------------------- + +struct ScratchBuffers { + + const size_t batchXYFloatBytes; + const size_t batchFloatBytes; + const size_t batchXYBytes; + const size_t batchBytes; + + SimpleAllocator* allocator; + + // Not scratch, but convenient to have here + void* zeroBuf; + void* oneBuf; + + ScratchBuffers() = delete; + ScratchBuffers(const ScratchBuffers&) = delete; + ScratchBuffers& operator=(const ScratchBuffers&) = delete; + + ScratchBuffers(int maxBatchSize, int nnXLen, int nnYLen, bool useFP16) + : batchXYFloatBytes((size_t)maxBatchSize * nnXLen * nnYLen * sizeof(float)), + batchFloatBytes((size_t)maxBatchSize * sizeof(float)), + batchXYBytes((size_t)maxBatchSize * nnXLen * nnYLen * (useFP16 ? sizeof(half_t) : sizeof(float))), + batchBytes((size_t)maxBatchSize * (useFP16 ? sizeof(half_t) : sizeof(float))) + { + std::function allocateFunc = [](size_t size) { + void* buf; + CUDA_ERR("ScratchBuffers",hipMalloc(&buf, size)); + return buf; + }; + std::function releaseFunc = [](void* buf) { + hipFree(buf); + }; + + allocator = new SimpleAllocator(allocateFunc, releaseFunc); + + CudaUtils::hostMallocZeroOneBufs(zeroBuf, oneBuf, useFP16); + } + ~ScratchBuffers() { + delete allocator; + free(zeroBuf); + free(oneBuf); + } + + size_t getBufSizeXY(int channels) const { + return channels * batchXYBytes; + } + size_t getBufSizeXYFloat(int channels) const { + return channels * batchXYFloatBytes; + } + size_t getBufSizeFloat(int channels) const { + return channels * batchFloatBytes; + } + size_t getBufSize(int channels) const { + return channels * batchBytes; + } + +}; + + +//--------------------------------------------------------------------------------- + +struct ConvLayer { + const string name; + const int inChannels; + const int outChannels; + const int nnXLen; + const int nnYLen; + const int maxBatchSize; + const bool usingFP16; + ByBatchSizeView inputDescriptors; + ByBatchSizeView outputDescriptors; + miopenTensorDescriptor_t filterDescriptor; + miopenConvolutionDescriptor_t convolutionDescriptor; + ByBatchSize* convolutionAlgorithms; //array of one for each batch size + void* filterBuf; + void* accumBuf; // Pre-allocated buffer for residual accumulation (miopenConvolutionForwardImmediate has no beta) + + ConvLayer() = delete; + ConvLayer(const ConvLayer&) = delete; + ConvLayer& operator=(const ConvLayer&) = delete; + + ConvLayer( + CudaHandles* cudaHandles, + CudnnManager* manager, + const ConvLayerDesc* desc, + bool useFP16, + bool useNHWC + ) : ConvLayer(cudaHandles, manager, desc, useFP16, useNHWC, useNHWC) + {} + + ConvLayer( + CudaHandles* cudaHandles, + CudnnManager* manager, + const ConvLayerDesc* desc, + bool useFP16, + bool useNHWCIn, + bool useNHWCOut + ) : + name(desc->name), + inChannels(desc->inChannels), + outChannels(desc->outChannels), + nnXLen(manager->nnXLen), + nnYLen(manager->nnYLen), + maxBatchSize(manager->maxBatchSize), + usingFP16(useFP16) + { + int convYSize = desc->convYSize; + int convXSize = desc->convXSize; + int dilationY = desc->dilationY; + int dilationX = desc->dilationX; + int paddingX = (convXSize / 2) * dilationX; + int paddingY = (convYSize / 2) * dilationY; + + assert(convXSize % 2 == 1); + assert(convYSize % 2 == 1); + + inputDescriptors = manager->getTensorDesc4DByBatchSize(inChannels,useFP16,useNHWCIn); + outputDescriptors = manager->getTensorDesc4DByBatchSize(outChannels,useFP16,useNHWCOut); + int maxBatchSize = manager->maxBatchSize; + + bool filterNHWC = useNHWCOut && dilationY == 1 && dilationX == 1; + + CUDNN_ERR(name.c_str(),miopenCreateTensorDescriptor(&filterDescriptor)); + CUDNN_ERR(name.c_str(),miopenSet4dTensorDescriptor( + filterDescriptor, + (useFP16 ? miopenHalf : miopenFloat), + outChannels, + inChannels, + convYSize, + convXSize + )); + + int yStride = 1; + int xStride = 1; + + + CUDNN_ERR(name.c_str(),miopenCreateConvolutionDescriptor(&convolutionDescriptor)); + CUDNN_ERR(name.c_str(),miopenInitConvolutionDescriptor( + convolutionDescriptor, + miopenConvolution, + paddingY, + paddingX, + yStride, + xStride, + dilationY, + dilationX + )); + if(useFP16) { + int alt = 1; // non‑zero enables alt‑impl on MI2xx+ GPUs + CUDNN_ERR(name.c_str(),miopenSetConvolutionAttribute(convolutionDescriptor,MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL,alt)); + } + + convolutionAlgorithms = new ByBatchSize(maxBatchSize); + + for(int batchSize = 1; batchSize <= maxBatchSize; batchSize++) { + const miopenTensorDescriptor_t& inputDescriptor = inputDescriptors[batchSize]; + const miopenTensorDescriptor_t& outputDescriptor = outputDescriptors[batchSize]; + size_t availableAlgoCount = 0; + CUDNN_ERR(name.c_str(),miopenConvolutionForwardGetSolutionCount( + cudaHandles->cudnn, + filterDescriptor, + inputDescriptor, + convolutionDescriptor, + outputDescriptor, + &availableAlgoCount + )); + if(availableAlgoCount <= 0) + throw StringError("miopenConvolutionForwardGetSolutionCount returned 0 algorithms?"); + std::vector solutions(availableAlgoCount); + size_t returnedAlgoCount = 0; + CUDNN_ERR(name.c_str(),miopenConvolutionForwardGetSolution( + cudaHandles->cudnn, + filterDescriptor, + inputDescriptor, + convolutionDescriptor, + outputDescriptor, + availableAlgoCount, + &returnedAlgoCount, + solutions.data() + )); + if(returnedAlgoCount <= 0) + throw StringError("miopenConvolutionForwardGetSolution returned no algorithms?"); + (*convolutionAlgorithms)[batchSize] = solutions[0]; + CUDNN_ERR(name.c_str(),miopenConvolutionForwardCompileSolution( + cudaHandles->cudnn, + filterDescriptor, + inputDescriptor, + convolutionDescriptor, + outputDescriptor, + (*convolutionAlgorithms)[batchSize].solution_id + )); + } + + assert(desc->weights.size() == convYSize * convXSize * inChannels * outChannels); + + if(filterNHWC) { + vector weightsTransposed(desc->weights.size()); + for(int y = 0; y < convYSize; y++) { + for(int x = 0; x < convXSize; x++) { + for(int ic = 0; ic < inChannels; ic++) { + for(int oc = 0; oc < outChannels; oc++) { + weightsTransposed[((oc*convYSize + y)*convXSize + x)*inChannels + ic] = + desc->weights[((oc*inChannels + ic)*convYSize + y)*convXSize + x]; + } + } + } + } + CudaUtils::mallocAndCopyToDevice(name,weightsTransposed,filterBuf,useFP16); + hipDeviceSynchronize(); + } + else + CudaUtils::mallocAndCopyToDevice(name,desc->weights,filterBuf,useFP16); + + // Pre-allocate buffer for accumulate mode (residual skip connections). + // miopenConvolutionForwardImmediate does not support alpha/beta unlike cuDNN, + // so we need to save the output before conv and add it back afterwards. + { + int elemSize = usingFP16 ? sizeof(half_t) : sizeof(float); + size_t accumBytes = (size_t)maxBatchSize * outChannels * nnXLen * nnYLen * elemSize; + CUDA_ERR(name.c_str(), hipMalloc(&accumBuf, accumBytes)); + } + } + + ~ConvLayer() { + hipFree(filterBuf); + hipFree(accumBuf); + miopenDestroyTensorDescriptor(filterDescriptor); + miopenDestroyConvolutionDescriptor(convolutionDescriptor); + delete convolutionAlgorithms; + } + + size_t requiredWorkspaceBytes( + CudaHandles* cudaHandles, + int batchSize + ) const { + size_t workspaceBytes = 0; + CUDNN_ERR(name.c_str(),miopenConvolutionForwardGetSolutionWorkspaceSize( + cudaHandles->cudnn, + filterDescriptor, + inputDescriptors[batchSize], + convolutionDescriptor, + outputDescriptors[batchSize], + (*convolutionAlgorithms)[batchSize].solution_id, + &workspaceBytes + )); + return workspaceBytes; + } + + void apply( + CudaHandles* cudaHandles, + int batchSize, + bool accumulate, + void* inputBuf, + void* outputBuf, + void* workspaceBuf, + size_t workspaceBytes + ) const { + // miopenConvolutionForwardImmediate does NOT support alpha/beta (unlike cuDNN). + // When accumulate=true, we need: outputBuf = conv(inputBuf) + outputBuf (residual skip connection). + // Save outputBuf content to pre-allocated accumBuf, run conv, then add back. + if(accumulate) { + int elemSize = usingFP16 ? sizeof(half) : sizeof(float); + size_t outputBytes = (size_t)batchSize * outChannels * nnXLen * nnYLen * elemSize; + CUDA_ERR(name.c_str(), hipMemcpyAsync(accumBuf, outputBuf, outputBytes, hipMemcpyDeviceToDevice)); + } + + CUDNN_ERR(name.c_str(), miopenConvolutionForwardImmediate( + cudaHandles->cudnn, + filterDescriptor, + filterBuf, + inputDescriptors[batchSize], + inputBuf, + convolutionDescriptor, + outputDescriptors[batchSize], + outputBuf, + workspaceBuf, + workspaceBytes, + (*convolutionAlgorithms)[batchSize].solution_id + )); + + if(accumulate) { + int outputElems = (int)((size_t)batchSize * outChannels * nnXLen * nnYLen); + if(usingFP16) + customCudaAddTensorsInplace((half*)outputBuf, (const half*)accumBuf, outputElems); + else + customCudaAddTensorsInplace((float*)outputBuf, (const float*)accumBuf, outputElems); + } + } + +}; + +//--------------------------------------------------------------------------------- + +struct BatchNormLayer { + const string name; + const int numChannels; + const float epsilon; + const int activation; + const int nnXLen; + const int nnYLen; + + const bool usingFP16; + const bool usingNHWC; + + void* mergedScaleBuf; + void* mergedBiasBuf; + + BatchNormLayer() = delete; + BatchNormLayer(const BatchNormLayer&) = delete; + BatchNormLayer& operator=(const BatchNormLayer&) = delete; + + BatchNormLayer( + CudaHandles* cudaHandles, + const BatchNormLayerDesc* desc, + const ActivationLayerDesc* actDesc, + int nnX, + int nnY, + bool useFP16, + bool useNHWC + ) : + name(desc->name), + numChannels(desc->numChannels), + epsilon(desc->epsilon), + activation(actDesc->activation), + nnXLen(nnX), + nnYLen(nnY), + usingFP16(useFP16), + usingNHWC(useNHWC) + { + (void)cudaHandles; + + assert(desc->mean.size() == numChannels); + assert(desc->variance.size() == numChannels); + assert(desc->scale.size() == numChannels); + assert(desc->bias.size() == numChannels); + assert(desc->mergedScale.size() == numChannels); + assert(desc->mergedBias.size() == numChannels); + CudaUtils::mallocAndCopyToDevice(name,desc->mergedScale,mergedScaleBuf,useFP16); + CudaUtils::mallocAndCopyToDevice(name,desc->mergedBias,mergedBiasBuf,useFP16); + } + ~BatchNormLayer() { + hipFree(mergedScaleBuf); + hipFree(mergedBiasBuf); + } + + void apply( + CudaHandles* cudaHandles, + int batchSize, + void* inputBuf, + const void* maskBuf, //ok to be null + void* outputBuf + ) const { + (void)cudaHandles; + if(!usingFP16) { + if(!usingNHWC) + customCudaApplyCScaleBiasNCHW((const float*)inputBuf,(float*)outputBuf,(const float*)mergedScaleBuf,(const float*)mergedBiasBuf, + (const float*)maskBuf, + batchSize,numChannels,nnXLen*nnYLen,activation); + else + customCudaApplyCScaleBiasNHWC((const float*)inputBuf,(float*)outputBuf,(const float*)mergedScaleBuf,(const float*)mergedBiasBuf, + (const float*)maskBuf, + batchSize,nnXLen*nnYLen,numChannels,activation); + } + else { + if(!usingNHWC) + customCudaApplyCScaleBiasNCHW((const half*)inputBuf,(half*)outputBuf,(const half*)mergedScaleBuf,(const half*)mergedBiasBuf, + (const half*)maskBuf, + batchSize,numChannels,nnXLen*nnYLen,activation); + else + customCudaApplyCScaleBiasNHWC((const half*)inputBuf,(half*)outputBuf,(const half*)mergedScaleBuf,(const half*)mergedBiasBuf, + (const half*)maskBuf, + batchSize,nnXLen*nnYLen,numChannels,activation); + CUDA_ERR(name.c_str(),hipPeekAtLastError()); + } + + } + +}; + + +//--------------------------------------------------------------------------------- + +struct MatMulLayer { + const string name; + const int inChannels; + const int outChannels; + const bool usingFP16; + void* matBuf; + + MatMulLayer() = delete; + MatMulLayer(const MatMulLayer&) = delete; + MatMulLayer& operator=(const MatMulLayer&) = delete; + + MatMulLayer( + CudaHandles* cudaHandles, + const MatMulLayerDesc* desc, + bool useFP16 + ) : + name(desc->name), + inChannels(desc->inChannels), + outChannels(desc->outChannels), + usingFP16(useFP16) + { + (void)cudaHandles; + + if(inChannels > 0 && outChannels > 0) { + assert(desc->weights.size() == inChannels * outChannels); + CudaUtils::mallocAndCopyToDevice(name,desc->weights,matBuf,useFP16); + } + else { + matBuf = NULL; + } + } + + ~MatMulLayer() { + if(inChannels > 0 && outChannels > 0) + hipFree(matBuf); + } + + size_t requiredWorkspaceBytes( + CudaHandles* cudaHandles + ) const { + (void)cudaHandles; + size_t workspaceBytes = 0; + return workspaceBytes; + } + + void apply( + CudaHandles* cudaHandles, + ScratchBuffers* scratch, + int batchSize, + void* inputBuf, + void* outputBuf, + void* workspaceBuf, + size_t workspaceBytes + ) const { + (void)workspaceBuf; + (void)workspaceBytes; + assert(inChannels > 0 && outChannels > 0); + + if(!usingFP16) { + const float alpha = 1.0f; + const float beta = 0.0f; + CUBLAS_ERR(name.c_str(),hipblasSgemm( + cudaHandles->cublas, + HIPBLAS_OP_N, + HIPBLAS_OP_N, + outChannels, + batchSize, + inChannels, + &alpha, + (const float*)matBuf,outChannels, + (const float*)inputBuf,inChannels, + &beta, + (float*)outputBuf,outChannels + )); + } + else { + const hipblasHalf* alpha = (const hipblasHalf*)scratch->oneBuf; + const hipblasHalf* beta = (const hipblasHalf*)scratch->zeroBuf; + CUBLAS_ERR(name.c_str(),hipblasHgemm( + cudaHandles->cublas, + HIPBLAS_OP_N, + HIPBLAS_OP_N, + outChannels, + batchSize, + inChannels, + alpha, + (const hipblasHalf*)matBuf,outChannels, + (const hipblasHalf*)inputBuf,inChannels, + beta, + (hipblasHalf*)outputBuf,outChannels + )); + } + + } + +}; + +//--------------------------------------------------------------------------------- + +struct MatBiasLayer { + const string name; + const int numChannels; + const bool usingFP16; + const int activation; + + void* biasBuf; + + MatBiasLayer() = delete; + MatBiasLayer(const MatBiasLayer&) = delete; + MatBiasLayer& operator=(const MatBiasLayer&) = delete; + + MatBiasLayer( + CudaHandles* cudaHandles, + const MatBiasLayerDesc* desc, + bool useFP16, + int activation_ + ) : + name(desc->name), + numChannels(desc->numChannels), + usingFP16(useFP16), + activation(activation_) + { + (void)cudaHandles; + if(numChannels > 0) { + assert(desc->weights.size() == numChannels); + CudaUtils::mallocAndCopyToDevice(name,desc->weights,biasBuf,useFP16); + } + else + biasBuf = NULL; + } + + ~MatBiasLayer() { + if(numChannels > 0) + hipFree(biasBuf); + } + + void apply( + CudaHandles* cudaHandles, + int batchSize, + void* matBuf + ) const { + (void)cudaHandles; + assert(numChannels > 0); + if(!usingFP16) { + customCudaAddCBiasInplaceNC((float*)matBuf,(const float*)biasBuf,batchSize,numChannels,activation); + CUDA_ERR(name.c_str(),hipPeekAtLastError()); + } + else { + customCudaAddCBiasInplaceNC((half*)matBuf,(const half*)biasBuf,batchSize,numChannels,activation); + CUDA_ERR(name.c_str(),hipPeekAtLastError()); + } + } + +}; + +//--------------------------------------------------------------------------------- + +struct NormActConv { + const BatchNormLayer norm; + const ConvLayer conv; + + const int inChannels; + const int outChannels; + const int nnXLen; + const int nnYLen; + const bool usingFP16; + const bool usingNHWC; + + NormActConv() = delete; + NormActConv(const NormActConv&) = delete; + NormActConv& operator=(const NormActConv&) = delete; + + NormActConv( + CudaHandles* cudaHandles, + CudnnManager* manager, + const BatchNormLayerDesc* normDesc, + const ActivationLayerDesc* actDesc, + const ConvLayerDesc* convDesc, + int nnX, + int nnY, + bool useFP16, + bool useNHWC + ): norm(cudaHandles,normDesc,actDesc,nnX,nnY,useFP16,useNHWC), + conv(cudaHandles,manager,convDesc,useFP16,useNHWC), + inChannels(norm.numChannels), + outChannels(conv.outChannels), + nnXLen(nnX), + nnYLen(nnY), + usingFP16(useFP16), + usingNHWC(useNHWC) + { + assert(norm.numChannels == conv.inChannels); + } + + ~NormActConv() + {} + + size_t requiredWorkspaceBytes( + CudaHandles* cudaHandles, + int batchSize + ) const { + size_t bytes = 0; + size_t b; + b = conv.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + return bytes; + } + + void apply( + CudaHandles* cudaHandles, + int batchSize, + bool accumulate, + void* inBuf, + void* inScratchBuf, + void* outBuf, + void* maskBuf, + void* workspaceBuf, + size_t workspaceBytes + ) const { + norm.apply(cudaHandles,batchSize,inBuf,maskBuf,inScratchBuf); +#ifdef DEBUG_INTERMEDIATE_VALUES + CudaUtils::debugPrint4D(string("AFTER NORM "), inScratchBuf, batchSize, inChannels, nnXLen, nnYLen, usingNHWC, usingFP16); +#endif + conv.apply(cudaHandles,batchSize,accumulate,inScratchBuf,outBuf,workspaceBuf,workspaceBytes); + } + +}; + + +//--------------------------------------------------------------------------------- + +struct ResidualBlock { + const string name; + const NormActConv normActConv1; + const NormActConv normActConv2; + + ResidualBlock() = delete; + ResidualBlock(const ResidualBlock&) = delete; + ResidualBlock& operator=(const ResidualBlock&) = delete; + + ResidualBlock( + CudaHandles* cudaHandles, + CudnnManager* manager, + const ResidualBlockDesc* desc, + int nnX, + int nnY, + bool useFP16, + bool useNHWC + ): name(desc->name), + normActConv1(cudaHandles,manager,&desc->preBN,&desc->preActivation,&desc->regularConv,nnX,nnY,useFP16,useNHWC), + normActConv2(cudaHandles,manager,&desc->midBN,&desc->midActivation,&desc->finalConv,nnX,nnY,useFP16,useNHWC) + { + } + + ~ResidualBlock() + {} + + size_t requiredWorkspaceBytes( + CudaHandles* cudaHandles, + int batchSize + ) const { + size_t bytes = 0; + size_t b; + b = normActConv1.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + b = normActConv2.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + return bytes; + } + + void apply( + CudaHandles* cudaHandles, + ScratchBuffers* scratch, + int batchSize, + void* trunkBuf, + void* trunkScratchBuf, + void* maskBuf, + void* workspaceBuf, + size_t workspaceBytes + ) const { + SizedBuf midIn(scratch->allocator, scratch->getBufSizeXY(normActConv1.outChannels)); + SizedBuf midScratch(scratch->allocator, scratch->getBufSizeXY(normActConv1.outChannels)); + normActConv1.apply(cudaHandles,batchSize,false,trunkBuf,trunkScratchBuf,midIn.buf,maskBuf,workspaceBuf,workspaceBytes); + normActConv2.apply(cudaHandles,batchSize,true,midIn.buf,midScratch.buf,trunkBuf,maskBuf,workspaceBuf,workspaceBytes); + } + +}; + + +//---------------------------------------------------------------------------- + + +struct GlobalPoolingResidualBlock { + const string name; + const BatchNormLayer preBN; + const ConvLayer regularConv; + const ConvLayer gpoolConv; + const BatchNormLayer gpoolBN; + const MatMulLayer gpoolToBiasMul; + const NormActConv normActConv2; + + const int nnXLen; + const int nnYLen; + const int regularChannels; + const int gpoolChannels; + const bool usingFP16; + const bool usingNHWC; + + GlobalPoolingResidualBlock() = delete; + GlobalPoolingResidualBlock(const GlobalPoolingResidualBlock&) = delete; + GlobalPoolingResidualBlock& operator=(const GlobalPoolingResidualBlock&) = delete; + + GlobalPoolingResidualBlock( + CudaHandles* cudaHandles, + CudnnManager* manager, + const GlobalPoolingResidualBlockDesc* desc, + int nnX, + int nnY, + bool useFP16, + bool useNHWC + ): name(desc->name), + preBN(cudaHandles,&desc->preBN,&desc->preActivation,nnX,nnY,useFP16,useNHWC), + regularConv(cudaHandles,manager,&desc->regularConv,useFP16,useNHWC), + gpoolConv(cudaHandles,manager,&desc->gpoolConv,useFP16,useNHWC), + gpoolBN(cudaHandles,&desc->gpoolBN,&desc->gpoolActivation,nnX,nnY,useFP16,useNHWC), + gpoolToBiasMul(cudaHandles,&desc->gpoolToBiasMul,useFP16), + normActConv2(cudaHandles,manager,&desc->midBN,&desc->midActivation,&desc->finalConv,nnX,nnY,useFP16,useNHWC), + nnXLen(nnX), + nnYLen(nnY), + regularChannels(desc->regularConv.outChannels), + gpoolChannels(desc->gpoolConv.outChannels), + usingFP16(useFP16), + usingNHWC(useNHWC) + { + } + + ~GlobalPoolingResidualBlock() { + } + + size_t requiredWorkspaceBytes( + CudaHandles* cudaHandles, + int batchSize + ) const { + size_t bytes = 0; + size_t b; + b = regularConv.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + b = gpoolConv.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + b = gpoolToBiasMul.requiredWorkspaceBytes(cudaHandles); + bytes = std::max(bytes,b); + b = normActConv2.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + b = sizeof(float)*batchSize*gpoolChannels*nnXLen*nnYLen; + bytes = std::max(bytes,b); + return bytes; + } + + void apply( + CudaHandles* cudaHandles, + ScratchBuffers* scratch, + int batchSize, + void* trunkBuf, + void* trunkScratchBuf, + void* maskBuf, + float* maskSumBuf, + void* workspaceBuf, + size_t workspaceBytes + ) const { + SizedBuf regularOut(scratch->allocator, scratch->getBufSizeXY(regularChannels)); + SizedBuf regularScratch(scratch->allocator, scratch->getBufSizeXY(regularChannels)); + SizedBuf gpoolOut(scratch->allocator, scratch->getBufSizeXY(gpoolChannels)); + SizedBuf gpoolOut2(scratch->allocator, scratch->getBufSizeXY(gpoolChannels)); + SizedBuf gpoolConcat(scratch->allocator, scratch->getBufSize(gpoolChannels*3)); + SizedBuf gpoolBias(scratch->allocator, scratch->getBufSize(regularChannels)); + + preBN.apply(cudaHandles,batchSize,trunkBuf,maskBuf,trunkScratchBuf); + regularConv.apply(cudaHandles,batchSize,false,trunkScratchBuf,regularOut.buf,workspaceBuf,workspaceBytes); + gpoolConv.apply(cudaHandles,batchSize,false,trunkScratchBuf,gpoolOut.buf,workspaceBuf,workspaceBytes); + gpoolBN.apply(cudaHandles,batchSize,gpoolOut.buf,maskBuf,gpoolOut2.buf); + + if(!usingFP16) { + if(!usingNHWC) + customCudaPoolRowsGPoolNCHW((const float*)gpoolOut2.buf,(float*)gpoolConcat.buf,batchSize,gpoolChannels,nnXLen*nnYLen,(const float*)maskBuf,maskSumBuf); + else + customCudaPoolRowsGPoolNHWC((const float*)gpoolOut2.buf,(float*)gpoolConcat.buf,batchSize,nnXLen*nnYLen,gpoolChannels,(const float*)maskBuf,maskSumBuf); + } + else { + if(!usingNHWC) + customCudaPoolRowsGPoolNCHW((const half*)gpoolOut2.buf,(half*)gpoolConcat.buf,batchSize,gpoolChannels,nnXLen*nnYLen,(const half*)maskBuf,maskSumBuf); + else + customCudaPoolRowsGPoolNHWC((const half*)gpoolOut2.buf,(half*)gpoolConcat.buf,batchSize,nnXLen*nnYLen,gpoolChannels,(const half*)maskBuf,maskSumBuf); + } + CUDA_ERR(name.c_str(),hipPeekAtLastError()); + + gpoolToBiasMul.apply(cudaHandles,scratch,batchSize,gpoolConcat.buf,gpoolBias.buf,workspaceBuf,workspaceBytes); + + if(!usingFP16) { + if(!usingNHWC) + customCudaAddNCBiasInplaceNCHW((float*)regularOut.buf,(const float*)gpoolBias.buf,batchSize,regularChannels,nnXLen*nnYLen); + else + customCudaAddNCBiasInplaceNHWC((float*)regularOut.buf,(const float*)gpoolBias.buf,batchSize,nnXLen*nnYLen,regularChannels); + } + else { + if(!usingNHWC) + customCudaAddNCBiasInplaceNCHW((half*)regularOut.buf,(const half*)gpoolBias.buf,batchSize,regularChannels,nnXLen*nnYLen); + else + customCudaAddNCBiasInplaceNHWC((half*)regularOut.buf,(const half*)gpoolBias.buf,batchSize,nnXLen*nnYLen,regularChannels); + } + CUDA_ERR(name.c_str(),hipPeekAtLastError()); + + normActConv2.apply(cudaHandles,batchSize,true,regularOut.buf,regularScratch.buf,trunkBuf,maskBuf,workspaceBuf,workspaceBytes); + } + +}; + +//------------------------------------------------------------------------------ + +struct BlockStack { + const int numBlocks; + const int trunkNumChannels; + const int nnXLen; + const int nnYLen; + const bool usingFP16; + const bool usingNHWC; + vector> blocks; + + BlockStack() = delete; + BlockStack(const BlockStack&) = delete; + BlockStack& operator=(const BlockStack&) = delete; + + BlockStack( + CudaHandles* cudaHandles, + CudnnManager* manager, + int nBlocks, + int trunkChannels, + const std::vector>& descBlocks, + int nnX, + int nnY, + bool useFP16, + bool useNHWC + ); + ~BlockStack(); + + size_t requiredWorkspaceBytes( + CudaHandles* cudaHandles, + int batchSize + ) const; + + void apply( + CudaHandles* cudaHandles, + ScratchBuffers* scratch, + int batchSize, + void* maskBuf, + float* maskSumBuf, + void* trunkBuf, + void* trunkScratchBuf, + void* workspaceBuf, + size_t workspaceBytes + ) const; + +}; + +//------------------------------------------------------------------------------ + +struct NestedBottleneckResidualBlock { + const string name; + const NormActConv normActConv1; + const BlockStack blocks; + const NormActConv normActConv2; + + NestedBottleneckResidualBlock() = delete; + NestedBottleneckResidualBlock(const NestedBottleneckResidualBlock&) = delete; + NestedBottleneckResidualBlock& operator=(const NestedBottleneckResidualBlock&) = delete; + + NestedBottleneckResidualBlock( + CudaHandles* cudaHandles, + CudnnManager* manager, + const NestedBottleneckResidualBlockDesc* desc, + int nnX, + int nnY, + bool useFP16, + bool useNHWC + ): name(desc->name), + normActConv1(cudaHandles,manager,&desc->preBN,&desc->preActivation,&desc->preConv,nnX,nnY,useFP16,useNHWC), + blocks(cudaHandles,manager,desc->numBlocks,desc->preConv.outChannels,desc->blocks,nnX,nnY,useFP16,useNHWC), + normActConv2(cudaHandles,manager,&desc->postBN,&desc->postActivation,&desc->postConv,nnX,nnY,useFP16,useNHWC) + { + } + + ~NestedBottleneckResidualBlock() + {} + + size_t requiredWorkspaceBytes( + CudaHandles* cudaHandles, + int batchSize + ) const { + size_t bytes = 0; + size_t b; + b = normActConv1.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + b = blocks.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + b = normActConv2.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + return bytes; + } + + void apply( + CudaHandles* cudaHandles, + ScratchBuffers* scratch, + int batchSize, + void* trunkBuf, + void* trunkScratchBuf, + void* maskBuf, + float* maskSumBuf, + void* workspaceBuf, + size_t workspaceBytes + ) const { + SizedBuf mid(scratch->allocator, scratch->getBufSizeXY(normActConv1.outChannels)); + SizedBuf midScratch(scratch->allocator, scratch->getBufSizeXY(normActConv1.outChannels)); + assert(normActConv1.outChannels == normActConv2.inChannels); + normActConv1.apply(cudaHandles,batchSize,false,trunkBuf,trunkScratchBuf,mid.buf,maskBuf,workspaceBuf,workspaceBytes); + blocks.apply( + cudaHandles, + scratch, + batchSize, + maskBuf, + maskSumBuf, + mid.buf, + midScratch.buf, + workspaceBuf, + workspaceBytes + ); + normActConv2.apply(cudaHandles,batchSize,true,mid.buf,midScratch.buf,trunkBuf,maskBuf,workspaceBuf,workspaceBytes); + } + +}; + +//------------------------------------------------------------------------------ + +BlockStack::BlockStack( + CudaHandles* cudaHandles, + CudnnManager* manager, + int nBlocks, + int trunkChannels, + const std::vector>& descBlocks, + int nnX, + int nnY, + bool useFP16, + bool useNHWC +) : + numBlocks(nBlocks), + trunkNumChannels(trunkChannels), + nnXLen(nnX), + nnYLen(nnY), + usingFP16(useFP16), + usingNHWC(useNHWC) +{ + assert(numBlocks == descBlocks.size()); + for(int i = 0; irequiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + } + else if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) { + GlobalPoolingResidualBlock* block = (GlobalPoolingResidualBlock*)blocks[i].second.get(); + b = block->requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + } + else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) { + NestedBottleneckResidualBlock* block = (NestedBottleneckResidualBlock*)blocks[i].second.get(); + b = block->requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + } + else { + ASSERT_UNREACHABLE; + } + } + return bytes; +} + +void BlockStack::apply( + CudaHandles* cudaHandles, + ScratchBuffers* scratch, + int batchSize, + void* maskBuf, + float* maskSumBuf, + void* trunkBuf, + void* trunkScratchBuf, + void* workspaceBuf, + size_t workspaceBytes +) const { + + for(int i = 0; iapply( + cudaHandles, + scratch, + batchSize, + trunkBuf, + trunkScratchBuf, + maskBuf, + workspaceBuf, + workspaceBytes + ); + } + else if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) { + GlobalPoolingResidualBlock* block = (GlobalPoolingResidualBlock*)blocks[i].second.get(); + block->apply( + cudaHandles, + scratch, + batchSize, + trunkBuf, + trunkScratchBuf, + maskBuf, + maskSumBuf, + workspaceBuf, + workspaceBytes + ); + } + else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) { + NestedBottleneckResidualBlock* block = (NestedBottleneckResidualBlock*)blocks[i].second.get(); + block->apply( + cudaHandles, + scratch, + batchSize, + trunkBuf, + trunkScratchBuf, + maskBuf, + maskSumBuf, + workspaceBuf, + workspaceBytes + ); + } + else { + ASSERT_UNREACHABLE; + } + } +} +//------------------------------------------------------------------------------ + +struct SGFMetadataEncoder { + const string name; + + const bool usingFP16; + + const MatMulLayer mul1; + const MatBiasLayer bias1; + const MatMulLayer mul2; + const MatBiasLayer bias2; + const MatMulLayer mul3; + + SGFMetadataEncoder() = delete; + SGFMetadataEncoder(const SGFMetadataEncoder&) = delete; + SGFMetadataEncoder& operator=(const SGFMetadataEncoder&) = delete; + + SGFMetadataEncoder( + CudaHandles* cudaHandles, + const SGFMetadataEncoderDesc* desc, + bool useFP16 + ) : + name(desc->name), + usingFP16(useFP16), + mul1(cudaHandles,&desc->mul1,useFP16), + bias1(cudaHandles,&desc->bias1,useFP16,desc->act1.activation), + mul2(cudaHandles,&desc->mul2,useFP16), + bias2(cudaHandles,&desc->bias2,useFP16,desc->act2.activation), + mul3(cudaHandles,&desc->mul3,useFP16) + { + } + + ~SGFMetadataEncoder() + { + } + + size_t requiredWorkspaceBytes( + CudaHandles* cudaHandles, + int batchSize + ) const { + (void)batchSize; + size_t bytes = 0; + size_t b; + + b = mul1.requiredWorkspaceBytes(cudaHandles); + bytes = std::max(bytes,b); + b = mul2.requiredWorkspaceBytes(cudaHandles); + bytes = std::max(bytes,b); + b = mul3.requiredWorkspaceBytes(cudaHandles); + bytes = std::max(bytes,b); + + return bytes; + } + + void apply( + CudaHandles* cudaHandles, + ScratchBuffers* scratch, + int batchSize, + void* inputBuf, + void* outputBuf, + void* workspaceBuf, + size_t workspaceBytes + ) const { + SizedBuf internalBuf1(scratch->allocator, scratch->getBufSizeFloat(std::max(mul1.outChannels,mul2.outChannels))); + SizedBuf internalBuf2(scratch->allocator, scratch->getBufSizeFloat(std::max(mul1.outChannels,mul2.outChannels))); + + mul1.apply(cudaHandles,scratch,batchSize,inputBuf,internalBuf1.buf,workspaceBuf,workspaceBytes); + bias1.apply(cudaHandles,batchSize,internalBuf1.buf); + mul2.apply(cudaHandles,scratch,batchSize,internalBuf1.buf,internalBuf2.buf,workspaceBuf,workspaceBytes); + bias2.apply(cudaHandles,batchSize,internalBuf2.buf); + mul3.apply(cudaHandles,scratch,batchSize,internalBuf2.buf,outputBuf,workspaceBuf,workspaceBytes); + } + +}; + + +//---------------------------------------------------------------------------- + +struct Trunk { + const string name; + const int modelVersion; + const int numBlocks; + const int trunkNumChannels; + + const int nnXLen; + const int nnYLen; + const bool usingFP16; + const bool usingNHWC; + + std::unique_ptr initialConv; + std::unique_ptr initialMatMul; + std::unique_ptr sgfMetadataEncoder; + const BlockStack blocks; + std::unique_ptr trunkTipBN; + + Trunk() = delete; + Trunk(const Trunk&) = delete; + Trunk& operator=(const Trunk&) = delete; + + Trunk( + CudaHandles* cudaHandles, + CudnnManager* manager, + const TrunkDesc* desc, + int nnX, + int nnY, + bool inputsUseNHWC, + bool useFP16, + bool useNHWC + ) : + name(desc->name), + modelVersion(desc->modelVersion), + numBlocks(desc->numBlocks), + trunkNumChannels(desc->trunkNumChannels), + nnXLen(nnX), + nnYLen(nnY), + usingFP16(useFP16), + usingNHWC(useNHWC), + blocks(cudaHandles,manager,desc->numBlocks,desc->trunkNumChannels,desc->blocks,nnX,nnY,useFP16,useNHWC) + { + int midNumChannels = desc->midNumChannels; + int regularNumChannels = desc->regularNumChannels; + int gpoolNumChannels = desc->gpoolNumChannels; + + int maxBatchSize = manager->maxBatchSize; + CudaUtils::checkBufferSize(maxBatchSize,nnXLen,nnYLen,trunkNumChannels); + CudaUtils::checkBufferSize(maxBatchSize,nnXLen,nnYLen,midNumChannels); + CudaUtils::checkBufferSize(maxBatchSize,nnXLen,nnYLen,regularNumChannels); + CudaUtils::checkBufferSize(maxBatchSize,nnXLen,nnYLen,gpoolNumChannels); + + initialConv = std::make_unique(cudaHandles,manager,&desc->initialConv,useFP16,inputsUseNHWC,useNHWC); + initialMatMul = std::make_unique(cudaHandles,&desc->initialMatMul,useFP16); + if(desc->metaEncoderVersion > 0) { + sgfMetadataEncoder = std::make_unique(cudaHandles,&desc->sgfMetadataEncoder,useFP16); + testAssert(sgfMetadataEncoder->mul3.outChannels == initialMatMul->outChannels); + } + + trunkTipBN = std::make_unique(cudaHandles,&desc->trunkTipBN,&desc->trunkTipActivation,nnXLen,nnYLen,useFP16,useNHWC); + assert(desc->blocks.size() == numBlocks); + } + + ~Trunk() + { + } + + size_t requiredWorkspaceBytes( + CudaHandles* cudaHandles, + int batchSize + ) const { + size_t bytes = 0; + size_t b; + + b = initialConv->requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + + b = initialMatMul->requiredWorkspaceBytes(cudaHandles); + bytes = std::max(bytes,b); + + if(sgfMetadataEncoder != nullptr) { + b = sgfMetadataEncoder->requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + } + + b = blocks.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + return bytes; + } + + void apply( + CudaHandles* cudaHandles, + ScratchBuffers* scratch, + int batchSize, + void* inputBuf, + void* inputGlobalBuf, + void* inputMetaBuf, + void* maskBuf, + float* maskSumBuf, + void* trunkBuf, + void* workspaceBuf, + size_t workspaceBytes + ) const { + + SizedBuf trunkScratch(scratch->allocator, scratch->getBufSizeXY(trunkNumChannels)); + + //Feed the conv into trunkScratch.buf, not trunkBuf + initialConv->apply(cudaHandles,batchSize,false,inputBuf,trunkScratch.buf,workspaceBuf,workspaceBytes); + + #ifdef DEBUG_INTERMEDIATE_VALUES + CudaUtils::debugPrint4D(string("After initial conv"), trunkScratch.buf, batchSize, trunkNumChannels, nnXLen, nnYLen, usingNHWC, usingFP16); + #endif + + //Feed the matmul into trunkBuf + initialMatMul->apply(cudaHandles,scratch,batchSize,inputGlobalBuf,trunkBuf,workspaceBuf,workspaceBytes); + //Then accumulate it into trunkScratch.buf, broadcasting during the process + if(!usingFP16) { + if(!usingNHWC) + customCudaAddNCBiasInplaceNCHW((float*)trunkScratch.buf,(const float*)trunkBuf,batchSize,trunkNumChannels,nnXLen*nnYLen); + else + customCudaAddNCBiasInplaceNHWC((float*)trunkScratch.buf,(const float*)trunkBuf,batchSize,nnXLen*nnYLen,trunkNumChannels); + } + else { + if(!usingNHWC) + customCudaAddNCBiasInplaceNCHW((half*)trunkScratch.buf,(const half*)trunkBuf,batchSize,trunkNumChannels,nnXLen*nnYLen); + else + customCudaAddNCBiasInplaceNHWC((half*)trunkScratch.buf,(const half*)trunkBuf,batchSize,nnXLen*nnYLen,trunkNumChannels); + } + CUDA_ERR(name.c_str(),hipPeekAtLastError()); + + if(sgfMetadataEncoder != nullptr) { + testAssert(inputMetaBuf != NULL); + //Feed the result into trunkBuf + sgfMetadataEncoder->apply(cudaHandles,scratch,batchSize,inputMetaBuf,trunkBuf,workspaceBuf,workspaceBytes); + //Then accumulate it into trunkScratch.buf, broadcasting during the process + if(!usingFP16) { + if(!usingNHWC) + customCudaAddNCBiasInplaceNCHW((float*)trunkScratch.buf,(const float*)trunkBuf,batchSize,trunkNumChannels,nnXLen*nnYLen); + else + customCudaAddNCBiasInplaceNHWC((float*)trunkScratch.buf,(const float*)trunkBuf,batchSize,nnXLen*nnYLen,trunkNumChannels); + } + else { + if(!usingNHWC) + customCudaAddNCBiasInplaceNCHW((half*)trunkScratch.buf,(const half*)trunkBuf,batchSize,trunkNumChannels,nnXLen*nnYLen); + else + customCudaAddNCBiasInplaceNHWC((half*)trunkScratch.buf,(const half*)trunkBuf,batchSize,nnXLen*nnYLen,trunkNumChannels); + } + CUDA_ERR(name.c_str(),hipPeekAtLastError()); + } + else { + testAssert(inputMetaBuf == NULL); + } + + //Flip trunkBuf and trunkScratch.buf so that the result gets accumulated in trunkScratch.buf + blocks.apply( + cudaHandles, + scratch, + batchSize, + maskBuf, + maskSumBuf, + trunkScratch.buf, + trunkBuf, + workspaceBuf, + workspaceBytes + ); + + //And now with the final BN port it from trunkScratch.buf to trunkBuf. + trunkTipBN->apply(cudaHandles,batchSize,trunkScratch.buf,maskBuf,trunkBuf); + + #ifdef DEBUG_INTERMEDIATE_VALUES + CudaUtils::debugPrint4D(string("Trunk tip"), trunkBuf, batchSize, trunkNumChannels, nnXLen, nnYLen, usingNHWC, usingFP16); + #endif + } + +}; + +//------------------------------------------------------------------------------ + +static void fillMaskFloatBufAndMaskSumBuf(void* maskBuf, float*& maskFloatBuf, float*& maskSumBuf, bool usingFP16, int batchSize, int nnXLen, int nnYLen) { + if(!usingFP16) { + maskFloatBuf = (float*)maskBuf; + customCudaPoolRowsSumNCHW((const float*)maskFloatBuf,maskSumBuf,batchSize,1,nnXLen*nnYLen,1.0); + CUDA_ERR("sumMask",hipPeekAtLastError()); + } + else { + customCudaCopyFromHalf((const half*)maskBuf,maskFloatBuf,batchSize*nnXLen*nnYLen); + CUDA_ERR("copyMaskFromHalf",hipPeekAtLastError()); + customCudaPoolRowsSumNCHW((const float*)maskFloatBuf,maskSumBuf,batchSize,1,nnXLen*nnYLen,1.0); + CUDA_ERR("sumMask",hipPeekAtLastError()); + } +} + + +//------------------------------------------------------------------------------ + +struct PolicyHead { + const string name; + const int modelVersion; + const int nnXLen; + const int nnYLen; + const int p1Channels; + const int g1Channels; + const int p2Channels; + const bool usingFP16; + const bool usingNHWC; + + const ConvLayer p1Conv; + const ConvLayer g1Conv; + const BatchNormLayer g1BN; + const MatMulLayer gpoolToBiasMul; + const BatchNormLayer p1BN; + const ConvLayer p2Conv; + const MatMulLayer gpoolToPassMul; + const MatBiasLayer gpoolToPassBias; + const MatMulLayer gpoolToPassMul2; + + PolicyHead() = delete; + PolicyHead(const PolicyHead&) = delete; + PolicyHead& operator=(const PolicyHead&) = delete; + + PolicyHead( + CudaHandles* cudaHandles, + CudnnManager* manager, + const PolicyHeadDesc* desc, + int nnX, + int nnY, + bool useFP16, + bool useNHWC + ) : + name(desc->name), + modelVersion(desc->modelVersion), + nnXLen(nnX), + nnYLen(nnY), + p1Channels(desc->p1Conv.outChannels), + g1Channels(desc->g1Conv.outChannels), + p2Channels(desc->p2Conv.outChannels), + usingFP16(useFP16), + usingNHWC(useNHWC), + p1Conv(cudaHandles,manager,&desc->p1Conv,useFP16,useNHWC), + g1Conv(cudaHandles,manager,&desc->g1Conv,useFP16,useNHWC), + g1BN(cudaHandles,&desc->g1BN,&desc->g1Activation,nnX,nnY,useFP16,useNHWC), + gpoolToBiasMul(cudaHandles,&desc->gpoolToBiasMul,false), + p1BN(cudaHandles,&desc->p1BN,&desc->p1Activation,nnX,nnY,false,useNHWC), + p2Conv(cudaHandles,manager,&desc->p2Conv,false,useNHWC), + gpoolToPassMul(cudaHandles,&desc->gpoolToPassMul,false), + gpoolToPassBias(cudaHandles,&desc->gpoolToPassBias,false,desc->passActivation.activation), + gpoolToPassMul2(cudaHandles,&desc->gpoolToPassMul2,false) + { + } + + ~PolicyHead() + { + } + + size_t requiredWorkspaceBytes( + CudaHandles* cudaHandles, + int batchSize + ) const { + size_t bytes = 0; + size_t b; + + b = p1Conv.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + b = g1Conv.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + b = gpoolToBiasMul.requiredWorkspaceBytes(cudaHandles); + bytes = std::max(bytes,b); + b = p2Conv.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + b = gpoolToPassMul.requiredWorkspaceBytes(cudaHandles); + bytes = std::max(bytes,b); + b = gpoolToPassMul2.requiredWorkspaceBytes(cudaHandles); + bytes = std::max(bytes,b); + b = sizeof(float)*batchSize*g1Channels*nnXLen*nnYLen; + bytes = std::max(bytes,b); + + return bytes; + } + + void apply( + CudaHandles* cudaHandles, + ScratchBuffers* scratch, + int batchSize, + void* maskBuf, + float* maskFloatBuf, + float* maskSumBuf, + void* trunkBuf, + float* policyPassBuf, + float* policyBuf, + void* workspaceBuf, + size_t workspaceBytes + ) const { + + SizedBuf p1Out(scratch->allocator, scratch->getBufSizeXYFloat(p1Channels)); //Need to hold floats, not just halfs + SizedBuf p1Out2(scratch->allocator, scratch->getBufSizeXYFloat(p1Channels)); //Need to hold floats, not just halfs + SizedBuf g1Out(scratch->allocator, scratch->getBufSizeXY(g1Channels)); + SizedBuf g1Out2(scratch->allocator, scratch->getBufSizeXY(g1Channels)); + SizedBuf g1Concat(scratch->allocator, scratch->getBufSizeFloat(g1Channels*3)); + SizedBuf g1Bias(scratch->allocator, scratch->getBufSizeFloat(p1Channels)); + SizedBuf p1Pass(scratch->allocator, scratch->getBufSizeFloat(p1Channels)); + + p1Conv.apply(cudaHandles,batchSize,false,trunkBuf,p1Out.buf,workspaceBuf,workspaceBytes); + g1Conv.apply(cudaHandles,batchSize,false,trunkBuf,g1Out.buf,workspaceBuf,workspaceBytes); + g1BN.apply(cudaHandles,batchSize,g1Out.buf,maskBuf,g1Out2.buf); + + if(!usingFP16) { + if(!usingNHWC) + customCudaPoolRowsGPoolNCHW((const float*)g1Out2.buf,(float*)g1Concat.buf,batchSize,g1Channels,nnXLen*nnYLen,maskFloatBuf,maskSumBuf); + else + customCudaPoolRowsGPoolNHWC((const float*)g1Out2.buf,(float*)g1Concat.buf,batchSize,nnXLen*nnYLen,g1Channels,maskFloatBuf,maskSumBuf); + CUDA_ERR(name.c_str(),hipPeekAtLastError()); + } + else { + customCudaCopyFromHalf((const half*)g1Out2.buf,(float*)workspaceBuf,batchSize*g1Channels*nnXLen*nnYLen); + CUDA_ERR(name.c_str(),hipPeekAtLastError()); + if(!usingNHWC) + customCudaPoolRowsGPoolNCHW((const float*)workspaceBuf,(float*)g1Concat.buf,batchSize,g1Channels,nnXLen*nnYLen,maskFloatBuf,maskSumBuf); + else + customCudaPoolRowsGPoolNHWC((const float*)workspaceBuf,(float*)g1Concat.buf,batchSize,nnXLen*nnYLen,g1Channels,maskFloatBuf,maskSumBuf); + CUDA_ERR(name.c_str(),hipPeekAtLastError()); + } + + gpoolToBiasMul.apply(cudaHandles,scratch,batchSize,g1Concat.buf,g1Bias.buf,workspaceBuf,workspaceBytes); + + #ifdef DEBUG_INTERMEDIATE_VALUES + CudaUtils::debugPrint4D(string("p1 pre-gpool-sum"), p1Out.buf, batchSize, p1Channels, nnXLen, nnYLen, usingNHWC, usingFP16); + CudaUtils::debugPrint4D(string("g1 pre-gpool"), g1Out.buf, batchSize, g1Channels, nnXLen, nnYLen, usingNHWC, usingFP16); + CudaUtils::debugPrint2D(string("g1 pooled"), g1Concat.buf, batchSize, g1Channels*3, false); + CudaUtils::debugPrint2D(string("g1 biases"), g1Bias.buf, batchSize, p1Channels, false); + #endif + + float* p1OutBufA; + float* p1OutBufB; + if(!usingFP16) { + p1OutBufA = (float*)p1Out.buf; + p1OutBufB = (float*)p1Out2.buf; + } + else { + customCudaCopyFromHalf((const half*)p1Out.buf,(float*)p1Out2.buf,batchSize*p1Channels*nnXLen*nnYLen); + CUDA_ERR(name.c_str(),hipPeekAtLastError()); + p1OutBufA = (float*)p1Out2.buf; + p1OutBufB = (float*)p1Out.buf; + } + + if(!usingNHWC) + customCudaAddNCBiasInplaceNCHW(p1OutBufA,(float*)g1Bias.buf,batchSize,p1Channels,nnXLen*nnYLen); + else + customCudaAddNCBiasInplaceNHWC(p1OutBufA,(float*)g1Bias.buf,batchSize,nnXLen*nnYLen,p1Channels); + CUDA_ERR(name.c_str(),hipPeekAtLastError()); + + p1BN.apply(cudaHandles,batchSize,p1OutBufA,maskFloatBuf,p1OutBufB); + p2Conv.apply(cudaHandles,batchSize,false,p1OutBufB,(float*)policyBuf,workspaceBuf,workspaceBytes); + + if(modelVersion >= 15) { + gpoolToPassMul.apply(cudaHandles,scratch,batchSize,g1Concat.buf,p1Pass.buf,workspaceBuf,workspaceBytes); + gpoolToPassBias.apply(cudaHandles,batchSize,p1Pass.buf); + gpoolToPassMul2.apply(cudaHandles,scratch,batchSize,p1Pass.buf,policyPassBuf,workspaceBuf,workspaceBytes); + } + else { + gpoolToPassMul.apply(cudaHandles,scratch,batchSize,g1Concat.buf,policyPassBuf,workspaceBuf,workspaceBytes); + } + + #ifdef DEBUG_INTERMEDIATE_VALUES + CudaUtils::debugPrint4D(string("p1 after-gpool-sum"), p1OutBufA, batchSize, p1Channels, nnXLen, nnYLen, usingNHWC, false); + CudaUtils::debugPrint2D(string("policypass"), policyPassBuf, batchSize, 1, false); + CudaUtils::debugPrint4D(string("policy"), policyBuf, batchSize, p2Channels, nnXLen, nnYLen, usingNHWC, false); + #endif + + } + +}; + +//------------------------------------------------------------------------------ + +struct ValueHead { + const string name; + const int modelVersion; + const int nnXLen; + const int nnYLen; + const int v1Channels; + const int v2Channels; + const int valueChannels; + const int scoreValueChannels; + const int ownershipChannels; + const bool usingFP16; + const bool usingNHWC; + + const ConvLayer v1Conv; + const BatchNormLayer v1BN; + const MatMulLayer v2Mul; + const MatBiasLayer v2Bias; + const MatMulLayer v3Mul; + const MatBiasLayer v3Bias; + const MatMulLayer sv3Mul; + const MatBiasLayer sv3Bias; + const ConvLayer vOwnershipConv; + + ValueHead() = delete; + ValueHead(const ValueHead&) = delete; + ValueHead& operator=(const ValueHead&) = delete; + + ValueHead( + CudaHandles* cudaHandles, + CudnnManager* manager, + const ValueHeadDesc* desc, + int nnX, + int nnY, + bool useFP16, + bool useNHWC + ) : + name(desc->name), + modelVersion(desc->modelVersion), + nnXLen(nnX), + nnYLen(nnY), + v1Channels(desc->v1Conv.outChannels), + v2Channels(desc->v2Mul.outChannels), + valueChannels(desc->v3Mul.outChannels), + scoreValueChannels(desc->sv3Mul.outChannels), + ownershipChannels(desc->vOwnershipConv.outChannels), + usingFP16(useFP16), + usingNHWC(useNHWC), + v1Conv(cudaHandles,manager,&desc->v1Conv,useFP16,useNHWC), + v1BN(cudaHandles,&desc->v1BN,&desc->v1Activation,nnX,nnY,useFP16,useNHWC), + v2Mul(cudaHandles,&desc->v2Mul,false), + v2Bias(cudaHandles,&desc->v2Bias,false,desc->v2Activation.activation), + v3Mul(cudaHandles,&desc->v3Mul,false), + v3Bias(cudaHandles,&desc->v3Bias,false,ACTIVATION_IDENTITY), + sv3Mul(cudaHandles,&desc->sv3Mul,false), + sv3Bias(cudaHandles,&desc->sv3Bias,false,ACTIVATION_IDENTITY), + vOwnershipConv(cudaHandles,manager,&desc->vOwnershipConv,useFP16,useNHWC) + { + } + + ~ValueHead() + { + } + + size_t requiredWorkspaceBytes( + CudaHandles* cudaHandles, + int batchSize + ) const { + size_t bytes = 0; + size_t b; + + b = v1Conv.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + b = v2Mul.requiredWorkspaceBytes(cudaHandles); + bytes = std::max(bytes,b); + b = v3Mul.requiredWorkspaceBytes(cudaHandles); + bytes = std::max(bytes,b); + b = sizeof(float)*batchSize*v1Channels*nnXLen*nnYLen; + bytes = std::max(bytes,b); + + b = sv3Mul.requiredWorkspaceBytes(cudaHandles); + bytes = std::max(bytes,b); + b = vOwnershipConv.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + b = sizeof(float)*batchSize*ownershipChannels*nnXLen*nnYLen; + bytes = std::max(bytes,b); + + return bytes; + } + + + void apply( + CudaHandles* cudaHandles, + ScratchBuffers* scratch, + int batchSize, + void* maskBuf, + float* maskSumBuf, + void* trunkBuf, + float* valueBuf, + float* scoreValueBuf, + void* ownershipBuf, + void* workspaceBuf, + size_t workspaceBytes + ) const { + SizedBuf v1Out(scratch->allocator, scratch->getBufSizeXY(v1Channels)); + SizedBuf v1Out2(scratch->allocator, scratch->getBufSizeXY(v1Channels)); + SizedBuf v1Mean(scratch->allocator, scratch->getBufSizeFloat(v1Channels*3)); + SizedBuf v2Out(scratch->allocator, scratch->getBufSizeFloat(v2Channels)); + SizedBuf ownershipScratch(scratch->allocator, scratch->getBufSizeXYFloat(ownershipChannels)); + + v1Conv.apply(cudaHandles,batchSize,false,trunkBuf,v1Out.buf,workspaceBuf,workspaceBytes); + v1BN.apply(cudaHandles,batchSize,v1Out.buf,maskBuf,v1Out2.buf); + + void* bufToBePooled = v1Out2.buf; + if(usingFP16) { + customCudaCopyFromHalf((const half*)v1Out2.buf,(float*)workspaceBuf,batchSize*v1Channels*nnXLen*nnYLen); + CUDA_ERR(name.c_str(),hipPeekAtLastError()); + bufToBePooled = workspaceBuf; + } + + if(!usingNHWC) + customCudaValueHeadPoolNCHW((float*)bufToBePooled,(float*)v1Mean.buf,batchSize,v1Channels,nnXLen*nnYLen,maskSumBuf); + else + customCudaValueHeadPoolNHWC((const float*)bufToBePooled,(float*)v1Mean.buf,batchSize,nnXLen*nnYLen,v1Channels,maskSumBuf); + CUDA_ERR(name.c_str(),hipPeekAtLastError()); + + v2Mul.apply(cudaHandles,scratch,batchSize,v1Mean.buf,v2Out.buf,workspaceBuf,workspaceBytes); + v2Bias.apply(cudaHandles,batchSize,v2Out.buf); + v3Mul.apply(cudaHandles,scratch,batchSize,v2Out.buf,valueBuf,workspaceBuf,workspaceBytes); + v3Bias.apply(cudaHandles,batchSize,valueBuf); + + sv3Mul.apply(cudaHandles,scratch,batchSize,v2Out.buf,scoreValueBuf,workspaceBuf,workspaceBytes); + sv3Bias.apply(cudaHandles,batchSize,scoreValueBuf); + + #ifdef DEBUG_INTERMEDIATE_VALUES + CudaUtils::debugPrint4D(string("v1"), v1Out.buf, batchSize, v1Channels, nnXLen, nnYLen, usingNHWC, usingFP16); + CudaUtils::debugPrint2D(string("v1 pooled"), v1Mean.buf, batchSize, v1Channels, false); + CudaUtils::debugPrint2D(string("v2"), v2Out.buf, batchSize, v1Channels, false); + #endif + + if(!usingFP16) { + vOwnershipConv.apply(cudaHandles,batchSize,false,v1Out2.buf,ownershipBuf,workspaceBuf,workspaceBytes); + } + else { + vOwnershipConv.apply(cudaHandles,batchSize,false,v1Out2.buf,ownershipScratch.buf,workspaceBuf,workspaceBytes); + customCudaCopyFromHalf((const half*)ownershipScratch.buf,(float*)ownershipBuf,batchSize*ownershipChannels*nnXLen*nnYLen); + CUDA_ERR("vOwnership copy",hipPeekAtLastError()); + } + + } + +}; + +//------------------------------------------------------------------------------ + +struct Model { + const string name; + const int modelVersion; + const int maxBatchSize; + const int nnXLen; + const int nnYLen; + const int numInputChannels; + const int numInputGlobalChannels; + const int numInputMetaChannels; + const int numPolicyChannels; + const int numValueChannels; + const int numScoreValueChannels; + const int numOwnershipChannels; + const bool usingFP16; + const bool usingNHWC; + const bool inputsUsingNHWC; + + std::unique_ptr trunk; + std::unique_ptr policyHead; + std::unique_ptr valueHead; + std::unique_ptr manager; + + Model() = delete; + Model(const Model&) = delete; + Model& operator=(const Model&) = delete; + + Model( + CudaHandles* cudaHandles, + const ModelDesc* desc, + int maxBatchSz, + int nnX, + int nnY, + bool inputsUseNHWC, + bool useFP16, + bool useNHWC + ) : + name(desc->name), + modelVersion(desc->modelVersion), + maxBatchSize(maxBatchSz), + nnXLen(nnX), + nnYLen(nnY), + numInputChannels(desc->numInputChannels), + numInputGlobalChannels(desc->numInputGlobalChannels), + numInputMetaChannels(desc->numInputMetaChannels), + numPolicyChannels(desc->numPolicyChannels), + numValueChannels(desc->numValueChannels), + numScoreValueChannels(desc->numScoreValueChannels), + numOwnershipChannels(desc->numOwnershipChannels), + usingFP16(useFP16), + usingNHWC(useNHWC), + inputsUsingNHWC(inputsUseNHWC) + { + if(nnXLen > NNPos::MAX_BOARD_LEN) + throw StringError(Global::strprintf("nnXLen (%d) is greater than NNPos::MAX_BOARD_LEN (%d)", + nnXLen, NNPos::MAX_BOARD_LEN + )); + if(nnYLen > NNPos::MAX_BOARD_LEN) + throw StringError(Global::strprintf("nnYLen (%d) is greater than NNPos::MAX_BOARD_LEN (%d)", + nnYLen, NNPos::MAX_BOARD_LEN + )); + + int numFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion); + if(numInputChannels != numFeatures) + throw StringError(Global::strprintf("Neural net numInputChannels (%d) was not the expected number based on version (%d)", + numInputChannels, numFeatures + )); + int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion); + if(numInputGlobalChannels != numGlobalFeatures) + throw StringError(Global::strprintf("Neural net numInputGlobalChannels (%d) was not the expected number based on version (%d)", + numInputGlobalChannels, numGlobalFeatures + )); + if(numInputMetaChannels > 0) { + if(numInputMetaChannels != SGFMetadata::METADATA_INPUT_NUM_CHANNELS) + throw StringError(Global::strprintf("Neural net numInputMetaChannels (%d) was not the expected number (%d)", + numInputMetaChannels, SGFMetadata::METADATA_INPUT_NUM_CHANNELS + )); + } + + CudaUtils::checkBufferSize(maxBatchSize,nnXLen,nnYLen,numInputChannels); + CudaUtils::checkBufferSize(maxBatchSize,nnXLen,nnYLen,numInputGlobalChannels); + CudaUtils::checkBufferSize(maxBatchSize,nnXLen,nnYLen,numInputMetaChannels); + CudaUtils::checkBufferSize(maxBatchSize,nnXLen,nnYLen,numPolicyChannels); + CudaUtils::checkBufferSize(maxBatchSize,nnXLen,nnYLen,numValueChannels); + CudaUtils::checkBufferSize(maxBatchSize,nnXLen,nnYLen,numScoreValueChannels); + CudaUtils::checkBufferSize(maxBatchSize,nnXLen,nnYLen,numOwnershipChannels); + + manager = std::make_unique(name, maxBatchSize, nnXLen, nnYLen); + trunk = std::make_unique(cudaHandles,manager.get(),&desc->trunk,nnXLen,nnYLen,inputsUseNHWC,useFP16,useNHWC); + policyHead = std::make_unique(cudaHandles,manager.get(),&desc->policyHead,nnXLen,nnYLen,useFP16,useNHWC); + valueHead = std::make_unique(cudaHandles,manager.get(),&desc->valueHead,nnXLen,nnYLen,useFP16,useNHWC); + } + + ~Model() + { + } + + size_t requiredWorkspaceBytes( + CudaHandles* cudaHandles, + int batchSize + ) const { + size_t bytes = 0; + size_t b; + + b = trunk->requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + b = policyHead->requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + b = valueHead->requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + + return bytes; + } + + void apply( + CudaHandles* cudaHandles, + ScratchBuffers* scratch, + int batchSize, + bool requireExactNNLen, + + void* inputBuf, + void* inputGlobalBuf, + void* inputMetaBuf, + + float* policyPassBuf, + float* policyBuf, + + float* valueBuf, + float* scoreValueBuf, + void* ownershipBuf, + + void* workspaceBuf, + size_t workspaceBytes + ) const { + SizedBuf mask(scratch->allocator, scratch->getBufSizeXY(1)); + SizedBuf maskFloat(scratch->allocator, scratch->getBufSizeXYFloat(1)); + SizedBuf maskSum(scratch->allocator, scratch->getBufSizeFloat(1)); + + void* maskBuf = mask.buf; + float* maskFloatBuf = (float*)maskFloat.buf; + float* maskSumBuf = (float*)maskSum.buf; + + if(!usingFP16) { + if(inputsUsingNHWC) + customCudaChannel0ExtractNHWC((const float*)inputBuf, (float*)maskBuf, batchSize, nnXLen*nnYLen, numInputChannels); + else + customCudaChannel0ExtractNCHW((const float*)inputBuf, (float*)maskBuf, batchSize, numInputChannels, nnXLen*nnYLen); + CUDA_ERR("modelExtractMask",hipPeekAtLastError()); + } + else { + if(inputsUsingNHWC) + customCudaChannel0ExtractNHWC((const half*)inputBuf, (half*)maskBuf, batchSize, nnXLen*nnYLen, numInputChannels); + else + customCudaChannel0ExtractNCHW((const half*)inputBuf, (half*)maskBuf, batchSize, numInputChannels, nnXLen*nnYLen); + CUDA_ERR("modelExtractMask",hipPeekAtLastError()); + } + + fillMaskFloatBufAndMaskSumBuf(maskBuf,maskFloatBuf,maskSumBuf,usingFP16,batchSize,nnXLen,nnYLen); + + //Don't do any masking if we know the board is exactly the desired size + if(requireExactNNLen) { + //Set to NULL to signal downstream that this buf doesn't need to be used + maskBuf = NULL; + maskFloatBuf = NULL; + //The global pooling structures need this no matter what, for normalizing based on this and its sqrt. + //maskSumBuf = NULL; + } + + #ifdef DEBUG_INTERMEDIATE_VALUES + CudaUtils::debugPrint4D(string("Initial bin features"), inputBuf, batchSize, trunk->initialConv->inChannels, nnXLen, nnYLen, inputsUsingNHWC, usingFP16); + CudaUtils::debugPrint2D(string("Initial global features"), inputGlobalBuf, batchSize, trunk->initialMatMul->inChannels, usingFP16); + if(trunk->sgfMetadataEncoder != nullptr) { + assert(inputMetaBuf != NULL); + CudaUtils::debugPrint2D(string("Initial meta features"), inputMetaBuf, batchSize, trunk->sgfMetadataEncoder->mul1.inChannels, usingFP16); + } + #endif + + SizedBuf trunkBuf(scratch->allocator, scratch->getBufSizeXY(trunk->trunkNumChannels)); + + trunk->apply( + cudaHandles, + scratch, + batchSize, + inputBuf, + inputGlobalBuf, + inputMetaBuf, + maskBuf, + maskSumBuf, + trunkBuf.buf, + workspaceBuf, + workspaceBytes + ); + policyHead->apply( + cudaHandles, + scratch, + batchSize, + maskBuf, + maskFloatBuf, + maskSumBuf, + trunkBuf.buf, + policyPassBuf, + policyBuf, + workspaceBuf, + workspaceBytes + ); + valueHead->apply( + cudaHandles, + scratch, + batchSize, + maskBuf, + maskSumBuf, + trunkBuf.buf, + valueBuf, + scoreValueBuf, + ownershipBuf, + workspaceBuf, + workspaceBytes + ); + } + +}; + + +//------------------------------------------------------------------------------ + +struct LoadedModel { + ModelDesc modelDesc; + + LoadedModel(const string& fileName, const string& expectedSha256) { + ModelDesc::loadFromFileMaybeGZipped(fileName,modelDesc,expectedSha256); + modelDesc.applyScale8ToReduceActivations(); + } + + LoadedModel() = delete; + LoadedModel(const LoadedModel&) = delete; + LoadedModel& operator=(const LoadedModel&) = delete; +}; + +LoadedModel* NeuralNet::loadModelFile(const string& file, const string& expectedSha256) { + LoadedModel* loadedModel = new LoadedModel(file,expectedSha256); + return loadedModel; +} + +void NeuralNet::freeLoadedModel(LoadedModel* loadedModel) { + delete loadedModel; +} + +const ModelDesc& NeuralNet::getModelDesc(const LoadedModel* loadedModel) { + return loadedModel->modelDesc; +} + +//------------------------------------------------------------------------------ + +struct Buffers { + //All of these are device pointers + + float* inputBufFloat; + void* inputBuf; + float* inputGlobalBufFloat; + void* inputGlobalBuf; + float* inputMetaBufFloat; + void* inputMetaBuf; + size_t inputBufBytesFloat; + size_t inputBufBytes; + size_t inputGlobalBufBytesFloat; + size_t inputGlobalBufBytes; + size_t inputMetaBufBytesFloat; + size_t inputMetaBufBytes; + + float* policyPassBuf; + size_t policyPassBufBytes; + float* policyBuf; + size_t policyBufBytes; + + float* valueBuf; + size_t valueBufBytes; + float* scoreValueBuf; + size_t scoreValueBufBytes; + void* ownershipBuf; + size_t ownershipBufBytes; + + void* workspaceBuf; + size_t workspaceBytes; + + Buffers() = delete; + Buffers(const Buffers&) = delete; + Buffers& operator=(const Buffers&) = delete; + + Buffers(CudaHandles* cudaHandles, const Model& m, const ScratchBuffers& scratch) { + size_t batchXYFloatBytes = (size_t)scratch.batchXYFloatBytes; + size_t batchFloatBytes = (size_t)scratch.batchFloatBytes; + size_t batchXYBytes = (size_t)scratch.batchXYBytes; + size_t batchBytes = (size_t)scratch.batchBytes; + + inputBufBytesFloat = m.numInputChannels * batchXYFloatBytes; + inputBufBytes = m.numInputChannels * batchXYBytes; + inputGlobalBufBytesFloat = m.numInputGlobalChannels * batchFloatBytes; + inputGlobalBufBytes = m.numInputGlobalChannels * batchBytes; + inputMetaBufBytesFloat = m.numInputMetaChannels * batchFloatBytes; + inputMetaBufBytes = m.numInputMetaChannels * batchBytes; + + CUDA_ERR("Buffers",hipMalloc(reinterpret_cast(&inputBufFloat), inputBufBytesFloat)); + CUDA_ERR("Buffers",hipMalloc(&inputBuf, inputBufBytes)); + CUDA_ERR("Buffers",hipMalloc(reinterpret_cast(&inputGlobalBufFloat), inputGlobalBufBytesFloat)); + CUDA_ERR("Buffers",hipMalloc(&inputGlobalBuf, inputGlobalBufBytes)); + if(m.numInputMetaChannels > 0) { + CUDA_ERR("Buffers",hipMalloc(reinterpret_cast(&inputMetaBufFloat), inputMetaBufBytesFloat)); + CUDA_ERR("Buffers",hipMalloc(&inputMetaBuf, inputMetaBufBytes)); + } + else { + inputMetaBufFloat = NULL; + inputMetaBuf = NULL; + } + + if(m.modelVersion >= 16) + testAssert(m.policyHead->p2Channels == 4); + else if(m.modelVersion >= 12) + testAssert(m.policyHead->p2Channels == 2); + else + testAssert(m.policyHead->p2Channels == 1); + + policyPassBufBytes = m.policyHead->p2Channels * batchFloatBytes; + CUDA_ERR("Buffers",hipMalloc(reinterpret_cast(&policyPassBuf), policyPassBufBytes)); + policyBufBytes = m.policyHead->p2Channels * batchXYFloatBytes; + CUDA_ERR("Buffers",hipMalloc(reinterpret_cast(&policyBuf), policyBufBytes)); + + valueBufBytes = m.valueHead->valueChannels * batchFloatBytes; + CUDA_ERR("Buffers",hipMalloc(reinterpret_cast(&valueBuf), valueBufBytes)); + + scoreValueBufBytes = m.valueHead->scoreValueChannels * batchFloatBytes; + CUDA_ERR("Buffers",hipMalloc(reinterpret_cast(&scoreValueBuf), scoreValueBufBytes)); + + //This buf is used for both an intermdiate fp16 result in fp16 mode, and ALSO the final fp32 output, so always must be fp32-sized + ownershipBufBytes = m.valueHead->ownershipChannels * batchXYFloatBytes; + CUDA_ERR("Buffers",hipMalloc(&ownershipBuf, ownershipBufBytes)); + + //In theory the requiredWorkspaceBytes calls could give us values non-monotone in batch size + //such as if the convolution algorithm changes between batch size 1 and larger. + //So we call it for all the batch sizes. + size_t bytes = 0; + size_t b; + for(int batchSize = 1; batchSize <= m.maxBatchSize; batchSize++) { + b = m.requiredWorkspaceBytes(cudaHandles,batchSize); + bytes = std::max(bytes,b); + } + + CUDA_ERR("Buffers",hipMalloc(&workspaceBuf, bytes)); + workspaceBytes = bytes; + } + + ~Buffers() { + hipFree(inputBufFloat); + hipFree(inputBuf); + hipFree(inputGlobalBufFloat); + hipFree(inputGlobalBuf); + if(inputMetaBufFloat != NULL) + hipFree(inputMetaBufFloat); + if(inputMetaBuf != NULL) + hipFree(inputMetaBuf); + + hipFree(policyPassBuf); + hipFree(policyBuf); + + hipFree(valueBuf); + hipFree(scoreValueBuf); + hipFree(ownershipBuf); + + hipFree(workspaceBuf); + } + +}; + +//------------------------------------------------------------------------------ + +struct ComputeContext { + int nnXLen; + int nnYLen; + enabled_t useFP16Mode; + enabled_t useNHWCMode; +}; + +ComputeContext* NeuralNet::createComputeContext( + const std::vector& gpuIdxs, + Logger* logger, + int nnXLen, + int nnYLen, + const string& openCLTunerFile, + const string& homeDataDirOverride, + bool openCLReTunePerBoardSize, + enabled_t useFP16Mode, + enabled_t useNHWCMode, + const LoadedModel* loadedModel +) { + (void)gpuIdxs; + (void)logger; + (void)openCLTunerFile; + (void)homeDataDirOverride; + (void)openCLReTunePerBoardSize; + (void)loadedModel; + + ComputeContext* context = new ComputeContext(); + context->nnXLen = nnXLen; + context->nnYLen = nnYLen; + context->useFP16Mode = useFP16Mode; + context->useNHWCMode = useNHWCMode; + return context; +} + +void NeuralNet::freeComputeContext(ComputeContext* computeContext) { + delete computeContext; +} + +//------------------------------------------------------------------------------ + +struct ComputeHandle { + std::unique_ptr cudaHandles; + std::unique_ptr model; + std::unique_ptr scratch; + std::unique_ptr buffers; + const bool usingFP16; + const int nnXLen; + const int nnYLen; + const bool requireExactNNLen; + const bool inputsUseNHWC; + const bool usingNHWC; + + ComputeHandle( + const ComputeContext* context, + const LoadedModel* loadedModel, + int majorComputeCapability, + int minorComputeCapability, + int maxBatchSize, + bool requireExactNNLen_, + bool inputsUseNHWC_, + bool useFP16, + bool useNHWC + ) : + usingFP16(useFP16), + nnXLen(context->nnXLen), + nnYLen(context->nnYLen), + requireExactNNLen(requireExactNNLen_), + inputsUseNHWC(inputsUseNHWC_), + usingNHWC(useNHWC) + { + cudaHandles = std::make_unique(majorComputeCapability,minorComputeCapability); + model = std::make_unique( + cudaHandles.get(), &(loadedModel->modelDesc), maxBatchSize, + nnXLen, nnYLen, inputsUseNHWC, useFP16, useNHWC + ); + scratch = std::make_unique(maxBatchSize, nnXLen, nnYLen, useFP16); + buffers = std::make_unique(cudaHandles.get(), *model, *scratch); + + //Synchronize after creating buffers and copying all the weights, just in case + CUDA_ERR("ComputeHandle", hipDeviceSynchronize()); + } + ~ComputeHandle() { + } + + ComputeHandle() = delete; + ComputeHandle(const ComputeHandle&) = delete; + ComputeHandle& operator=(const ComputeHandle&) = delete; +}; + +ComputeHandle* NeuralNet::createComputeHandle( + ComputeContext* context, + const LoadedModel* loadedModel, + Logger* logger, + int maxBatchSize, + bool requireExactNNLen, + bool inputsUseNHWC, + int gpuIdxForThisThread, + int serverThreadIdx +) { + //Use whatever CUDA believes GPU 0 to be. + if(gpuIdxForThisThread == -1) + gpuIdxForThisThread = 0; + + CUDA_ERR("createComputeHandle",hipSetDevice(gpuIdxForThisThread)); + + hipDeviceProp_t prop; + hipGetDeviceProperties(&prop,gpuIdxForThisThread); + + bool useFP16 = false; + bool useNHWC = false; + if(context->useFP16Mode == enabled_t::True || context->useFP16Mode == enabled_t::Auto) + useFP16 = true; + + if(logger != NULL) { + logger->write( + "ROCm backend thread " + Global::intToString(serverThreadIdx) + ": Found GPU " + string(prop.name) + + " memory " + Global::uint64ToString(prop.totalGlobalMem) + + " compute capability major " + Global::intToString(prop.major) + + " minor " + Global::intToString(prop.minor) + ); + logger->write( + "ROCm backend thread " + Global::intToString(serverThreadIdx) + ": Model version " + Global::intToString(loadedModel->modelDesc.modelVersion) + + " useFP16 = " + Global::boolToString(useFP16) + + " useNHWC = " + Global::boolToString(useNHWC) + ); + logger->write( + "ROCm backend thread " + Global::intToString(serverThreadIdx) + ": Model name: " + loadedModel->modelDesc.name + ); + logger->write( + "MIOpen finding convolution algorithms for GPU " + string(prop.name) + ". This may take a while, please wait......" + ); + } + + ComputeHandle* gpuHandle = new ComputeHandle( + context,loadedModel,prop.major,prop.minor,maxBatchSize,requireExactNNLen,inputsUseNHWC,useFP16,useNHWC + ); + return gpuHandle; +} + +void NeuralNet::freeComputeHandle(ComputeHandle* gpuHandle) { + delete gpuHandle; +} + +bool NeuralNet::isUsingFP16(const ComputeHandle* handle) { + return handle->usingFP16; +} + +//------------------------------------------------------------------------------ + +void NeuralNet::printDevices() { + int numDevices = 0; + hipGetDeviceCount(&numDevices); + for(int i = 0; imodelDesc; + + maxBatchSize = maxBatchSz; + singleInputElts = (size_t)m.numInputChannels * nnXLen * nnYLen; + singleInputBytes = (size_t)m.numInputChannels * nnXLen * nnYLen * sizeof(float); + singleInputGlobalElts = (size_t)m.numInputGlobalChannels; + singleInputGlobalBytes = (size_t)m.numInputGlobalChannels * sizeof(float); + singleInputMetaElts = (size_t)m.numInputMetaChannels; + singleInputMetaBytes = (size_t)m.numInputMetaChannels * sizeof(float); + singlePolicyPassResultElts = (size_t)(m.numPolicyChannels); + singlePolicyPassResultBytes = (size_t)(m.numPolicyChannels) * sizeof(float); + singlePolicyResultElts = (size_t)(m.numPolicyChannels * nnXLen * nnYLen); + singlePolicyResultBytes = (size_t)(m.numPolicyChannels * nnXLen * nnYLen) * sizeof(float); + singleValueResultElts = (size_t)m.numValueChannels; + singleValueResultBytes = (size_t)m.numValueChannels * sizeof(float); + singleScoreValueResultElts = (size_t)m.numScoreValueChannels; + singleScoreValueResultBytes = (size_t)m.numScoreValueChannels * sizeof(float); + singleOwnershipResultElts = (size_t)m.numOwnershipChannels * nnXLen * nnYLen; + singleOwnershipResultBytes = (size_t)m.numOwnershipChannels * nnXLen * nnYLen * sizeof(float); + + assert(NNModelVersion::getNumSpatialFeatures(m.modelVersion) == m.numInputChannels); + assert(NNModelVersion::getNumGlobalFeatures(m.modelVersion) == m.numInputGlobalChannels); + if(m.numInputMetaChannels > 0) { + assert(SGFMetadata::METADATA_INPUT_NUM_CHANNELS == m.numInputMetaChannels); + } + + userInputBufferBytes = (size_t)m.numInputChannels * maxBatchSize * nnXLen * nnYLen * sizeof(float); + userInputGlobalBufferBytes = (size_t)m.numInputGlobalChannels * maxBatchSize * sizeof(float); + userInputMetaBufferBytes = (size_t)m.numInputMetaChannels * maxBatchSize * sizeof(float); + policyPassResultBufferBytes = (size_t)maxBatchSize * m.numPolicyChannels * sizeof(float); + policyResultBufferBytes = (size_t)maxBatchSize * m.numPolicyChannels * nnXLen * nnYLen * sizeof(float); + valueResultBufferBytes = (size_t)maxBatchSize * m.numValueChannels * sizeof(float); + scoreValueResultBufferBytes = (size_t)maxBatchSize * m.numScoreValueChannels * sizeof(float); + ownershipResultBufferBytes = (size_t)maxBatchSize * nnXLen * nnYLen * m.numOwnershipChannels * sizeof(float); + + userInputBuffer = new float[(size_t)m.numInputChannels * maxBatchSize * nnXLen * nnYLen]; + userInputGlobalBuffer = new float[(size_t)m.numInputGlobalChannels * maxBatchSize]; + if(m.numInputMetaChannels > 0) + userInputMetaBuffer = new float[(size_t)m.numInputMetaChannels * maxBatchSize]; + else + userInputMetaBuffer = NULL; + + policyPassResults = new float[(size_t)maxBatchSize * m.numPolicyChannels]; + policyResults = new float[(size_t)maxBatchSize * m.numPolicyChannels * nnXLen * nnYLen]; + valueResults = new float[(size_t)maxBatchSize * m.numValueChannels]; + + scoreValueResults = new float[(size_t)maxBatchSize * m.numScoreValueChannels]; + ownershipResults = new float[(size_t)maxBatchSize * nnXLen * nnYLen * m.numOwnershipChannels]; + } + + ~InputBuffers() { + delete[] userInputBuffer; + delete[] userInputGlobalBuffer; + if(userInputMetaBuffer != NULL) + delete[] userInputMetaBuffer; + delete[] policyPassResults; + delete[] policyResults; + delete[] valueResults; + delete[] scoreValueResults; + delete[] ownershipResults; + } + + InputBuffers() = delete; + InputBuffers(const InputBuffers&) = delete; + InputBuffers& operator=(const InputBuffers&) = delete; + +}; + +InputBuffers* NeuralNet::createInputBuffers(const LoadedModel* loadedModel, int maxBatchSize, int nnXLen, int nnYLen) { + return new InputBuffers(loadedModel,maxBatchSize,nnXLen,nnYLen); +} +void NeuralNet::freeInputBuffers(InputBuffers* inputBuffers) { + delete inputBuffers; +} + +//--------------------------------------------------------------------------------------- + + +void NeuralNet::getOutput( + ComputeHandle* gpuHandle, + InputBuffers* inputBuffers, + int numBatchEltsFilled, + NNResultBuf** inputBufs, + vector& outputs +) { + assert(numBatchEltsFilled <= inputBuffers->maxBatchSize); + assert(numBatchEltsFilled > 0); + const int batchSize = numBatchEltsFilled; + const int nnXLen = gpuHandle->nnXLen; + const int nnYLen = gpuHandle->nnYLen; + const int modelVersion = gpuHandle->model->modelVersion; + + const int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion); + const int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion); + const int numMetaFeatures = inputBuffers->singleInputMetaElts; + assert(numSpatialFeatures == gpuHandle->model->numInputChannels); + assert(numSpatialFeatures * nnXLen * nnYLen == inputBuffers->singleInputElts); + assert(numGlobalFeatures == inputBuffers->singleInputGlobalElts); + const int numPolicyChannels = gpuHandle->model->numPolicyChannels; + + for(int nIdx = 0; nIdxuserInputBuffer + (inputBuffers->singleInputElts * nIdx); + float* rowGlobalInput = inputBuffers->userInputGlobalBuffer + (inputBuffers->singleInputGlobalElts * nIdx); + float* rowMetaInput = inputBuffers->userInputMetaBuffer + (inputBuffers->singleInputMetaElts * nIdx); + + const float* rowGlobal = inputBufs[nIdx]->rowGlobalBuf.data(); + const float* rowSpatial = inputBufs[nIdx]->rowSpatialBuf.data(); + const float* rowMeta = inputBufs[nIdx]->rowMetaBuf.data(); + bool hasRowMeta = inputBufs[nIdx]->hasRowMeta; + std::copy(rowGlobal,rowGlobal+numGlobalFeatures,rowGlobalInput); + if(numMetaFeatures > 0) { + testAssert(rowMeta != NULL); + testAssert(hasRowMeta); + std::copy(rowMeta,rowMeta+numMetaFeatures,rowMetaInput); + } + else { + testAssert(!hasRowMeta); + } + SymmetryHelpers::copyInputsWithSymmetry(rowSpatial, rowSpatialInput, 1, nnYLen, nnXLen, numSpatialFeatures, gpuHandle->inputsUseNHWC, inputBufs[nIdx]->symmetry); + } + + Buffers* buffers = gpuHandle->buffers.get(); + ScratchBuffers* scratch = gpuHandle->scratch.get(); + + if(!gpuHandle->usingFP16) { + assert(inputBuffers->userInputBufferBytes == buffers->inputBufBytes); + assert(inputBuffers->userInputGlobalBufferBytes == buffers->inputGlobalBufBytes); + assert(inputBuffers->userInputMetaBufferBytes == buffers->inputMetaBufBytes); + assert(inputBuffers->policyPassResultBufferBytes == buffers->policyPassBufBytes); + assert(inputBuffers->policyResultBufferBytes == buffers->policyBufBytes); + assert(inputBuffers->valueResultBufferBytes == buffers->valueBufBytes); + assert(inputBuffers->singleInputBytes == inputBuffers->singleInputElts*4); + assert(inputBuffers->singleInputGlobalBytes == inputBuffers->singleInputGlobalElts*4); + assert(inputBuffers->singleInputMetaBytes == inputBuffers->singleInputMetaElts*4); + assert(inputBuffers->singlePolicyPassResultElts == numPolicyChannels); + assert(inputBuffers->singlePolicyPassResultBytes == numPolicyChannels * sizeof(float)); + assert(inputBuffers->singlePolicyResultElts == numPolicyChannels*nnXLen*nnYLen); + assert(inputBuffers->singlePolicyResultBytes == numPolicyChannels*nnXLen*nnYLen * sizeof(float)); + assert(inputBuffers->scoreValueResultBufferBytes == buffers->scoreValueBufBytes); + assert(inputBuffers->ownershipResultBufferBytes == buffers->ownershipBufBytes); + assert(inputBuffers->singleOwnershipResultElts == nnXLen*nnYLen); + assert(inputBuffers->singleOwnershipResultBytes == nnXLen*nnYLen * sizeof(float)); + + CUDA_ERR("getOutput",hipMemcpy(buffers->inputBuf, inputBuffers->userInputBuffer, inputBuffers->singleInputBytes*batchSize, hipMemcpyHostToDevice)); + CUDA_ERR("getOutput",hipMemcpy(buffers->inputGlobalBuf, inputBuffers->userInputGlobalBuffer, inputBuffers->singleInputGlobalBytes*batchSize, hipMemcpyHostToDevice)); + if(numMetaFeatures > 0) { + CUDA_ERR("getOutput",hipMemcpy(buffers->inputMetaBuf, inputBuffers->userInputMetaBuffer, inputBuffers->singleInputMetaBytes*batchSize, hipMemcpyHostToDevice)); + } + } + else { + assert(inputBuffers->userInputBufferBytes == buffers->inputBufBytesFloat); + assert(inputBuffers->userInputGlobalBufferBytes == buffers->inputGlobalBufBytesFloat); + assert(inputBuffers->userInputMetaBufferBytes == buffers->inputMetaBufBytesFloat); + assert(inputBuffers->policyResultBufferBytes == buffers->policyBufBytes); + assert(inputBuffers->valueResultBufferBytes == buffers->valueBufBytes); + assert(inputBuffers->userInputBufferBytes == buffers->inputBufBytes*2); + assert(inputBuffers->userInputGlobalBufferBytes == buffers->inputGlobalBufBytes*2); + assert(inputBuffers->userInputMetaBufferBytes == buffers->inputMetaBufBytes*2); + assert(inputBuffers->singleInputBytes == inputBuffers->singleInputElts*4); + assert(inputBuffers->singleInputGlobalBytes == inputBuffers->singleInputGlobalElts*4); + assert(inputBuffers->singleInputMetaBytes == inputBuffers->singleInputMetaElts*4); + assert(inputBuffers->singlePolicyPassResultElts == numPolicyChannels); + assert(inputBuffers->singlePolicyPassResultBytes == numPolicyChannels * sizeof(float)); + assert(inputBuffers->singlePolicyResultElts == numPolicyChannels*nnXLen*nnYLen); + assert(inputBuffers->singlePolicyResultBytes == numPolicyChannels*nnXLen*nnYLen * sizeof(float)); + assert(inputBuffers->scoreValueResultBufferBytes == buffers->scoreValueBufBytes); + assert(inputBuffers->ownershipResultBufferBytes == buffers->ownershipBufBytes); + assert(inputBuffers->singleOwnershipResultElts == nnXLen*nnYLen); + assert(inputBuffers->singleOwnershipResultBytes == nnXLen*nnYLen * sizeof(float)); + + CUDA_ERR("getOutput",hipMemcpy(buffers->inputBufFloat, inputBuffers->userInputBuffer, inputBuffers->singleInputBytes*batchSize, hipMemcpyHostToDevice)); + CUDA_ERR("getOutput",hipMemcpy(buffers->inputGlobalBufFloat, inputBuffers->userInputGlobalBuffer, inputBuffers->singleInputGlobalBytes*batchSize, hipMemcpyHostToDevice)); + if(numMetaFeatures > 0) { + CUDA_ERR("getOutput",hipMemcpy(buffers->inputMetaBufFloat, inputBuffers->userInputMetaBuffer, inputBuffers->singleInputMetaBytes*batchSize, hipMemcpyHostToDevice)); + } + + customCudaCopyToHalf((const float*)buffers->inputBufFloat,(half*)buffers->inputBuf,inputBuffers->singleInputElts*batchSize); + CUDA_ERR("getOutput",hipPeekAtLastError()); + customCudaCopyToHalf((const float*)buffers->inputGlobalBufFloat,(half*)buffers->inputGlobalBuf,inputBuffers->singleInputGlobalElts*batchSize); + CUDA_ERR("getOutput",hipPeekAtLastError()); + if(numMetaFeatures > 0) { + customCudaCopyToHalf((const float*)buffers->inputMetaBufFloat,(half*)buffers->inputMetaBuf,inputBuffers->singleInputMetaElts*batchSize); + CUDA_ERR("getOutput",hipPeekAtLastError()); + } + } + + gpuHandle->model->apply( + gpuHandle->cudaHandles.get(), + scratch, + batchSize, + gpuHandle->requireExactNNLen, + + buffers->inputBuf, + buffers->inputGlobalBuf, + buffers->inputMetaBuf, + + buffers->policyPassBuf, + buffers->policyBuf, + + buffers->valueBuf, + buffers->scoreValueBuf, + buffers->ownershipBuf, + + buffers->workspaceBuf, + buffers->workspaceBytes + ); + + CUDA_ERR("getOutput",hipMemcpy(inputBuffers->policyPassResults, buffers->policyPassBuf, inputBuffers->singlePolicyPassResultBytes*batchSize, hipMemcpyDeviceToHost)); + CUDA_ERR("getOutput",hipMemcpy(inputBuffers->policyResults, buffers->policyBuf, inputBuffers->singlePolicyResultBytes*batchSize, hipMemcpyDeviceToHost)); + CUDA_ERR("getOutput",hipMemcpy(inputBuffers->valueResults, buffers->valueBuf, inputBuffers->singleValueResultBytes*batchSize, hipMemcpyDeviceToHost)); + CUDA_ERR("getOutput",hipMemcpy(inputBuffers->scoreValueResults, buffers->scoreValueBuf, inputBuffers->singleScoreValueResultBytes*batchSize, hipMemcpyDeviceToHost)); + CUDA_ERR("getOutput",hipMemcpy(inputBuffers->ownershipResults, buffers->ownershipBuf, inputBuffers->singleOwnershipResultBytes*batchSize, hipMemcpyDeviceToHost)); + + assert(outputs.size() == batchSize); + + float policyProbsTmp[NNPos::MAX_NN_POLICY_SIZE]; + + for(int row = 0; row < batchSize; row++) { + NNOutput* output = outputs[row]; + assert(output->nnXLen == nnXLen); + assert(output->nnYLen == nnYLen); + float policyOptimism = (float)inputBufs[row]->policyOptimism; + + const float* policyPassSrcBuf = inputBuffers->policyPassResults + row * numPolicyChannels; + const float* policySrcBuf = inputBuffers->policyResults + row * numPolicyChannels * nnXLen * nnYLen; + float* policyProbs = output->policyProbs; + + // These are in logits, the client does the postprocessing to turn them into + // policy probabilities and white game outcome probabilities + // Also we don't fill in the nnHash here either + // Handle version >= 12 policy optimism + if(numPolicyChannels == 2 || (numPolicyChannels == 4 && modelVersion >= 16)) { + if(gpuHandle->usingNHWC) { + for(int i = 0; isymmetry); + policyProbs[nnXLen*nnYLen] = policyPassSrcBuf[0] + (policyPassSrcBuf[1] - policyPassSrcBuf[0]) * policyOptimism; + } + else { + for(int i = 0; isymmetry); + policyProbs[nnXLen*nnYLen] = policyPassSrcBuf[0] + (policyPassSrcBuf[1] - policyPassSrcBuf[0]) * policyOptimism; + } + } + else { + assert(numPolicyChannels == 1); + SymmetryHelpers::copyOutputsWithSymmetry(policySrcBuf, policyProbs, 1, nnYLen, nnXLen, inputBufs[row]->symmetry); + policyProbs[nnXLen*nnYLen] = policyPassSrcBuf[0]; + } + + int numValueChannels = gpuHandle->model->numValueChannels; + assert(numValueChannels == 3); + output->whiteWinProb = inputBuffers->valueResults[row * numValueChannels]; + output->whiteLossProb = inputBuffers->valueResults[row * numValueChannels + 1]; + output->whiteNoResultProb = inputBuffers->valueResults[row * numValueChannels + 2]; + + //As above, these are NOT actually from white's perspective, but rather the player to move. + //As usual the client does the postprocessing. + if(output->whiteOwnerMap != NULL) { + const float* ownershipSrcBuf = inputBuffers->ownershipResults + row * nnXLen * nnYLen; + assert(gpuHandle->model->numOwnershipChannels == 1); + SymmetryHelpers::copyOutputsWithSymmetry(ownershipSrcBuf, output->whiteOwnerMap, 1, nnYLen, nnXLen, inputBufs[row]->symmetry); + } + + if(modelVersion >= 9) { + int numScoreValueChannels = gpuHandle->model->numScoreValueChannels; + assert(numScoreValueChannels == 6); + output->whiteScoreMean = inputBuffers->scoreValueResults[row * numScoreValueChannels]; + output->whiteScoreMeanSq = inputBuffers->scoreValueResults[row * numScoreValueChannels + 1]; + output->whiteLead = inputBuffers->scoreValueResults[row * numScoreValueChannels + 2]; + output->varTimeLeft = inputBuffers->scoreValueResults[row * numScoreValueChannels + 3]; + output->shorttermWinlossError = inputBuffers->scoreValueResults[row * numScoreValueChannels + 4]; + output->shorttermScoreError = inputBuffers->scoreValueResults[row * numScoreValueChannels + 5]; + } + else if(modelVersion >= 8) { + int numScoreValueChannels = gpuHandle->model->numScoreValueChannels; + assert(numScoreValueChannels == 4); + output->whiteScoreMean = inputBuffers->scoreValueResults[row * numScoreValueChannels]; + output->whiteScoreMeanSq = inputBuffers->scoreValueResults[row * numScoreValueChannels + 1]; + output->whiteLead = inputBuffers->scoreValueResults[row * numScoreValueChannels + 2]; + output->varTimeLeft = inputBuffers->scoreValueResults[row * numScoreValueChannels + 3]; + output->shorttermWinlossError = 0; + output->shorttermScoreError = 0; + } + else if(modelVersion >= 4) { + int numScoreValueChannels = gpuHandle->model->numScoreValueChannels; + assert(numScoreValueChannels == 2); + output->whiteScoreMean = inputBuffers->scoreValueResults[row * numScoreValueChannels]; + output->whiteScoreMeanSq = inputBuffers->scoreValueResults[row * numScoreValueChannels + 1]; + output->whiteLead = output->whiteScoreMean; + output->varTimeLeft = 0; + output->shorttermWinlossError = 0; + output->shorttermScoreError = 0; + } + else if(modelVersion >= 3) { + int numScoreValueChannels = gpuHandle->model->numScoreValueChannels; + assert(numScoreValueChannels == 1); + output->whiteScoreMean = inputBuffers->scoreValueResults[row * numScoreValueChannels]; + //Version 3 neural nets don't have any second moment output, implicitly already folding it in, so we just use the mean squared + output->whiteScoreMeanSq = output->whiteScoreMean * output->whiteScoreMean; + output->whiteLead = output->whiteScoreMean; + output->varTimeLeft = 0; + output->shorttermWinlossError = 0; + output->shorttermScoreError = 0; + } + else { + ASSERT_UNREACHABLE; + } + } + +} + +//TESTING ---------------------------------------------------------------------------------- + + +bool NeuralNet::testEvaluateConv( + const ConvLayerDesc* desc, + int desiredBatchSize, + int nnXLen, + int nnYLen, + bool useFP16, + bool useNHWC, + const vector& inputBuffer, + vector& outputBuffer +) { + hipDeviceSynchronize(); + CudaHandles* cudaHandles = CudaHandles::cudaHandlesTesting(); + + size_t numInputFloats = (size_t)desiredBatchSize * nnXLen * nnYLen * desc->inChannels; + size_t numOutputFloats = (size_t)desiredBatchSize * nnXLen * nnYLen * desc->outChannels; + if(numInputFloats != inputBuffer.size()) + throw StringError("testEvaluateConv: unexpected input buffer size"); + + void* deviceInput; + void* deviceOutput; + CudaUtils::mallocAndCopyToDevice("deviceInput", inputBuffer.data(), numInputFloats, deviceInput, useFP16); + CudaUtils::mallocOnDevice("deviceOutput", numOutputFloats, deviceOutput, useFP16); + + int maxBatchSize = desiredBatchSize; + + CudnnManager* manager = new CudnnManager("manager",maxBatchSize,nnXLen,nnYLen); + ConvLayer* convLayer = new ConvLayer(cudaHandles,manager,desc,useFP16,useNHWC); + + size_t workspaceBytes = + convLayer->requiredWorkspaceBytes(cudaHandles,desiredBatchSize); + void* deviceWorkspace; + CUDA_ERR("deviceWorkspace",hipMalloc(&deviceWorkspace, workspaceBytes)); + + + bool accumulate = false; + convLayer->apply( + cudaHandles, + desiredBatchSize, + accumulate, + deviceInput, + deviceOutput, + deviceWorkspace, + workspaceBytes + ); + + outputBuffer.resize(numOutputFloats); + CudaUtils::expensiveCopyFromDevice("copyResultsToHost", outputBuffer.data(), numOutputFloats, deviceOutput, useFP16); + + hipFree(deviceWorkspace); + + delete convLayer; + delete manager; + hipFree(deviceInput); + hipFree(deviceOutput); + delete cudaHandles; + + return true; +} + + +bool NeuralNet::testEvaluateBatchNorm( + const BatchNormLayerDesc* desc, + int desiredBatchSize, + int nnXLen, + int nnYLen, + bool useFP16, + bool useNHWC, + const vector& inputBuffer, + const vector& maskBuffer, + vector& outputBuffer +) { + hipDeviceSynchronize(); + CudaHandles* cudaHandles = CudaHandles::cudaHandlesTesting(); + + size_t numInputFloats = (size_t)desiredBatchSize * nnXLen * nnYLen * desc->numChannels; + size_t numMaskFloats = (size_t)desiredBatchSize * nnXLen * nnYLen; + size_t numOutputFloats = (size_t)desiredBatchSize * nnXLen * nnYLen * desc->numChannels; + if(numInputFloats != inputBuffer.size()) + throw StringError("testEvaluateBatchNorm: unexpected input buffer size"); + if(numMaskFloats != maskBuffer.size()) + throw StringError("testEvaluateBatchNorm: unexpected mask buffer size"); + + ActivationLayerDesc actDesc; + actDesc.activation = ACTIVATION_IDENTITY; + + void* deviceInput; + void* deviceMask; + void* deviceOutput; + CudaUtils::mallocAndCopyToDevice("deviceInput", inputBuffer.data(), numInputFloats, deviceInput, useFP16); + CudaUtils::mallocAndCopyToDevice("deviceMask", maskBuffer.data(), numMaskFloats, deviceMask, useFP16); + CudaUtils::mallocOnDevice("deviceOutput", numOutputFloats, deviceOutput, useFP16); + + BatchNormLayer* batchNormLayer = new BatchNormLayer(cudaHandles,desc,&actDesc,nnXLen,nnYLen,useFP16,useNHWC); + + batchNormLayer->apply( + cudaHandles, + desiredBatchSize, + deviceInput, + deviceMask, + deviceOutput + ); + + outputBuffer.resize(numOutputFloats); + CudaUtils::expensiveCopyFromDevice("copyResultsToHost", outputBuffer.data(), numOutputFloats, deviceOutput, useFP16); + + delete batchNormLayer; + + hipFree(deviceInput); + hipFree(deviceMask); + hipFree(deviceOutput); + delete cudaHandles; + + return true; +} + + +bool NeuralNet::testEvaluateResidualBlock( + const ResidualBlockDesc* desc, + int desiredBatchSize, + int nnXLen, + int nnYLen, + bool useFP16, + bool useNHWC, + const vector& inputBuffer, + const vector& maskBuffer, + vector& outputBuffer +) { + hipDeviceSynchronize(); + CudaHandles* cudaHandles = CudaHandles::cudaHandlesTesting(); + + size_t numInputFloats = (size_t)desiredBatchSize * nnXLen * nnYLen * desc->preBN.numChannels; + size_t numMaskFloats = (size_t)desiredBatchSize * nnXLen * nnYLen; + size_t numOutputFloats = (size_t)desiredBatchSize * nnXLen * nnYLen * desc->finalConv.outChannels; + if(numInputFloats != inputBuffer.size()) + throw StringError("testEvaluateResidualBlock: unexpected input buffer size"); + if(numMaskFloats != maskBuffer.size()) + throw StringError("testEvaluateResidualBlock: unexpected mask buffer size"); + + ScratchBuffers* scratch = new ScratchBuffers(desiredBatchSize, nnXLen, nnYLen, useFP16); + + void* deviceInput; + void* deviceMask; + void* deviceScratch; + CudaUtils::mallocAndCopyToDevice("deviceInput", inputBuffer.data(), numInputFloats, deviceInput, useFP16); + CudaUtils::mallocAndCopyToDevice("deviceMask", maskBuffer.data(), numMaskFloats, deviceMask, useFP16); + CudaUtils::mallocOnDevice("deviceScratch", numInputFloats, deviceScratch, useFP16); + + int maxBatchSize = desiredBatchSize; + + CudnnManager* manager = new CudnnManager("manager",maxBatchSize,nnXLen,nnYLen); + ResidualBlock* residualBlock = new ResidualBlock(cudaHandles,manager,desc,nnXLen,nnYLen,useFP16,useNHWC); + + size_t workspaceBytes = + residualBlock->requiredWorkspaceBytes(cudaHandles,desiredBatchSize); + void* deviceWorkspace; + CUDA_ERR("deviceWorkspace",hipMalloc(&deviceWorkspace, workspaceBytes)); + + residualBlock->apply( + cudaHandles, + scratch, + desiredBatchSize, + deviceInput, + deviceScratch, + deviceMask, + deviceWorkspace, + workspaceBytes + ); + + outputBuffer.resize(numOutputFloats); + CudaUtils::expensiveCopyFromDevice("copyResultsToHost", outputBuffer.data(), numOutputFloats, deviceInput, useFP16); + + hipFree(deviceWorkspace); + + delete residualBlock; + delete manager; + hipFree(deviceInput); + hipFree(deviceMask); + hipFree(deviceScratch); + delete scratch; + delete cudaHandles; + + return true; +} + +bool NeuralNet::testEvaluateGlobalPoolingResidualBlock( + const GlobalPoolingResidualBlockDesc* desc, + int desiredBatchSize, + int nnXLen, + int nnYLen, + bool useFP16, + bool useNHWC, + const vector& inputBuffer, + const vector& maskBuffer, + vector& outputBuffer +) { + hipDeviceSynchronize(); + CudaHandles* cudaHandles = CudaHandles::cudaHandlesTesting(); + + size_t numInputFloats = (size_t)desiredBatchSize * nnXLen * nnYLen * desc->preBN.numChannels; + size_t numMaskFloats = (size_t)desiredBatchSize * nnXLen * nnYLen; + size_t numMaskSumFloats = (size_t)desiredBatchSize; + size_t numOutputFloats = (size_t)desiredBatchSize * nnXLen * nnYLen * desc->finalConv.outChannels; + + if(numInputFloats != inputBuffer.size()) + throw StringError("testEvaluateGlobalPoolingResidualBlock: unexpected input buffer size"); + if(numMaskFloats != maskBuffer.size()) + throw StringError("testEvaluateGlobalPoolingResidualBlock: unexpected mask buffer size"); + + ScratchBuffers* scratch = new ScratchBuffers(desiredBatchSize, nnXLen, nnYLen, useFP16); + + void* deviceInput; + void* deviceMask; + float* deviceMaskFloatOrig; + float* deviceMaskFloat; + float* deviceMaskSum; + void* deviceScratch; + + CudaUtils::mallocAndCopyToDevice("deviceInput", inputBuffer.data(), numInputFloats, deviceInput, useFP16); + CudaUtils::mallocAndCopyToDevice("deviceMask", maskBuffer.data(), numMaskFloats, deviceMask, useFP16); + CUDA_ERR("deviceMaskFloat",hipMalloc(reinterpret_cast(&deviceMaskFloat), numMaskFloats * sizeof(float))); + CUDA_ERR("deviceMaskSum",hipMalloc(reinterpret_cast(&deviceMaskSum), numMaskSumFloats * sizeof(float))); + deviceMaskFloatOrig = deviceMaskFloat; + CudaUtils::mallocOnDevice("deviceScratch", numInputFloats, deviceScratch, useFP16); + + fillMaskFloatBufAndMaskSumBuf(deviceMask, deviceMaskFloat, deviceMaskSum, useFP16, desiredBatchSize, nnXLen, nnYLen); + + int maxBatchSize = desiredBatchSize; + + CudnnManager* manager = new CudnnManager("manager",maxBatchSize,nnXLen,nnYLen); + GlobalPoolingResidualBlock* residualBlock = new GlobalPoolingResidualBlock( + cudaHandles,manager,desc,nnXLen,nnYLen,useFP16,useNHWC + ); + + size_t workspaceBytes = + residualBlock->requiredWorkspaceBytes( + cudaHandles,desiredBatchSize + ); + + void* deviceWorkspace; + CUDA_ERR("deviceWorkspace",hipMalloc(&deviceWorkspace, workspaceBytes)); + + residualBlock->apply( + cudaHandles, + scratch, + desiredBatchSize, + deviceInput, + deviceScratch, + deviceMask, + deviceMaskSum, + deviceWorkspace, + workspaceBytes + ); + + outputBuffer.resize(numOutputFloats); + CudaUtils::expensiveCopyFromDevice("copyResultsToHost", outputBuffer.data(), numOutputFloats, deviceInput, useFP16); + + hipFree(deviceWorkspace); + + delete residualBlock; + delete manager; + + hipFree(deviceInput); + hipFree(deviceMask); + hipFree(deviceMaskFloatOrig); + hipFree(deviceMaskSum); + hipFree(deviceScratch); + delete scratch; + delete cudaHandles; + + return true; +} + + +#endif // USE_ROCM_BACKEND diff --git a/cpp/neuralnet/rocmerrorcheck.h b/cpp/neuralnet/rocmerrorcheck.h new file mode 100644 index 000000000..8cb214ad7 --- /dev/null +++ b/cpp/neuralnet/rocmerrorcheck.h @@ -0,0 +1,59 @@ +#ifndef NEURALNET_ROCMERRORCHECK_H_ +#define NEURALNET_ROCMERRORCHECK_H_ + +#include "../neuralnet/rocmincludes.h" +#include "../core/global.h" + +// ---------- HIP runtime ---------- +static inline void checkCudaError(hipError_t status, + const char* opName, + const char* file, + const char* func, + int line) { + if(status != hipSuccess) + throw StringError(std::string("HIP Error @") + opName + " " + + file + ":" + func + ":" + Global::intToString(line) + + " : " + hipGetErrorString(status)); +} +#define CUDA_ERR(opName,x) checkCudaError((x),opName,__FILE__,#x,__LINE__) + +// ---------- hipBLAS ---------- +static inline const char* cublasGetErrorString(hipblasStatus_t s) { + switch(s) { + case HIPBLAS_STATUS_SUCCESS: return "HIPBLAS_STATUS_SUCCESS"; + case HIPBLAS_STATUS_ALLOC_FAILED: return "HIPBLAS_STATUS_ALLOC_FAILED"; + case HIPBLAS_STATUS_MAPPING_ERROR: return "HIPBLAS_STATUS_MAPPING_ERROR"; + case HIPBLAS_STATUS_EXECUTION_FAILED: return "HIPBLAS_STATUS_EXECUTION_FAILED"; + case HIPBLAS_STATUS_INTERNAL_ERROR: return "HIPBLAS_STATUS_INTERNAL_ERROR"; + case HIPBLAS_STATUS_INVALID_VALUE: return "HIPBLAS_STATUS_INVALID_VALUE"; + case HIPBLAS_STATUS_NOT_INITIALIZED: return "HIPBLAS_STATUS_NOT_INITIALIZED"; + case HIPBLAS_STATUS_NOT_SUPPORTED: return "HIPBLAS_STATUS_NOT_SUPPORTED"; + default: return "HIPBLAS_STATUS_UNKNOWN"; + } +} +static inline void checkCublasError(hipblasStatus_t status, + const char* opName, + const char* file, + const char* func, + int line) { + if(status != HIPBLAS_STATUS_SUCCESS) + throw StringError(std::string("hipBLAS Error @") + opName + " " + + file + ":" + func + ":" + Global::intToString(line) + + " : " + cublasGetErrorString(status)); +} +#define CUBLAS_ERR(opName,x) checkCublasError((x),opName,__FILE__,#x,__LINE__) + +// ---------- MIOpen ---------- +static inline void checkCudnnError(miopenStatus_t status, + const char* opName, + const char* file, + const char* func, + int line) { + if(status != miopenStatusSuccess) + throw StringError(std::string("MIOpen Error @") + opName + " " + + file + ":" + func + ":" + Global::intToString(line) + + " : " + miopenGetErrorString(status)); +} +#define CUDNN_ERR(opName,x) checkCudnnError((x),opName,__FILE__,#x,__LINE__) + +#endif // NEURALNET_ROCMERRORCHECK_H_ diff --git a/cpp/neuralnet/rocmhelpers.h b/cpp/neuralnet/rocmhelpers.h new file mode 100644 index 000000000..6061e6165 --- /dev/null +++ b/cpp/neuralnet/rocmhelpers.h @@ -0,0 +1,62 @@ +#ifndef NEURALNET_ROCMHELPERS_H_ +#define NEURALNET_ROCMHELPERS_H_ + +#include "../neuralnet/rocmincludes.h" +#include "../neuralnet/activations.h" + +//Given two tensors with shapes inA: [n,cA,h,w] and inB: [n,cB,h,w], that are on the GPU +//Copy them into a single tensor out: [n,cA+cB,h,w] that is also allocated on the gpu +void customCudaChannelConcat(const float* inA, const float* inB, float* out, int chwA, int chwB, int n); +void customCudaChannelConcat(const half* inA, const half* inB, half* out, int chwA, int chwB, int n); + +//Given a tensor [n,c,hw], extract out channel 0 to [n,hw] +void customCudaChannel0ExtractNCHW(const float* in, float* out, int n, int c, int hw); +void customCudaChannel0ExtractNCHW(const half* in, half* out, int n, int c, int hw); +//Given a tensor [n,hw,c], extract out channel 0 to [n,hw] +void customCudaChannel0ExtractNHWC(const float* in, float* out, int n, int hw, int c); +void customCudaChannel0ExtractNHWC(const half* in, half* out, int n, int hw, int c); + +//Given an input tensor and an output buffer of shape [n,c], fill output buffer with sum or max over c. +void customCudaPoolRowsSumNCHW(const float* in, float* out, int nSize, int cSize, int xySize, float scaleSum); +void customCudaPoolRowsSumNHWC(const float* in, float* out, int nSize, int xySize, int cSize, float scaleSum); + +//Specialized operations for value head and general global pooling. Same as the other pooling, but fusedly fills +//an output buffer of shape [n,c*3]. +void customCudaValueHeadPoolNCHW(const float* in, float* out, int nSize, int cSize, int xySize, const float* maskSum); +void customCudaValueHeadPoolNHWC(const float* in, float* out, int nSize, int xySize, int cSize, const float* maskSum); +void customCudaPoolRowsGPoolNCHW(const float* in, float* out, int nSize, int cSize, int xySize, const float* mask, const float* maskSum); +void customCudaPoolRowsGPoolNHWC(const float* in, float* out, int nSize, int xySize, int cSize, const float* mask, const float* maskSum); +void customCudaPoolRowsGPoolNCHW(const half* in, half* out, int nSize, int cSize, int xySize, const half* mask, const float* maskSum); +void customCudaPoolRowsGPoolNHWC(const half* in, half* out, int nSize, int xySize, int cSize, const half* mask, const float* maskSum); + +void customCudaCopyToHalf(const float* in, half* out, int n); +void customCudaCopyFromHalf(const half* in, float* out, int n); + +//Given a tensor, add another tensor element-wise to it (same shape). +void customCudaAddTensorsInplace(float* buf, const float* toAdd, int n); +void customCudaAddTensorsInplace(half* buf, const half* toAdd, int n); +//Given a tensor, add another tensor to it. +void customCudaAddTensorInplace(half* buf, const half* biases, int n); +//Given an input with shape [n,c] and biases of shape [c], add the biases in-place. +void customCudaAddCBiasInplaceNC(float* buf, const float* biases, int n, int c, int activation); +void customCudaAddCBiasInplaceNC(half* buf, const half* biases, int n, int c, int activation); +//Given an input with shape [n,c,xy] and biases of shape [n,c], add the biases in-place. +void customCudaAddNCBiasInplaceNCHW(float *buf, const float* biases, int nSize, int cSize, int xySize); +void customCudaAddNCBiasInplaceNCHW(half *buf, const half* biases, int nSize, int cSize, int xySize); +//Given an input with shape [n,xy,c] and biases of shape [n,c], add the biases in-place. +void customCudaAddNCBiasInplaceNHWC(float *buf, const float* biases, int nSize, int xySize, int cSize); +void customCudaAddNCBiasInplaceNHWC(half *buf, const half* biases, int nSize, int xySize, int cSize); + +//Given an input with shape [n,c,xy] and scale and biases of shape [c], multiply by scale and add the biases +//Optionally also apply an activation. +//Optionally also multiply by mask (can be null), with shape [n,xy] +void customCudaApplyCScaleBiasNCHW(const float* in, float* out, const float* scale, const float* biases, const float* mask, int n, int c, int xy, int activation); +void customCudaApplyCScaleBiasNCHW(const half* in, half* out, const half* scale, const half* biases, const half* mask, int n, int c, int xy, int activation); +//Given an input with shape [n,xy,c] and scale and biases of shape [c], multiply by scale and add the biases +//Optionally also apply relu. +//Optionally also multiply by mask (can be null), with shape [n,xy] +void customCudaApplyCScaleBiasNHWC(const float* in, float* out, const float* scale, const float* biases, const float* mask, int n, int xy, int c, int activation); +void customCudaApplyCScaleBiasNHWC(const half* in, half* out, const half* scale, const half* biases, const half* mask, int n, int xy, int c, int activation); + + +#endif // NEURALNET_ROCMHELPERS_H_ diff --git a/cpp/neuralnet/rocmhelpers.hip b/cpp/neuralnet/rocmhelpers.hip new file mode 100644 index 000000000..7db6cb032 --- /dev/null +++ b/cpp/neuralnet/rocmhelpers.hip @@ -0,0 +1,1942 @@ +#include "hip/hip_runtime.h" + +#include "../neuralnet/rocmhelpers.h" + +#include + +#if defined(__HIP_ARCH_HAS_FP16__) || (defined(__HIP_DEVICE_COMPILE__) && (__HIP_ARCH_GFX803__ || __HIP_ARCH_GFX900__)) +#define HIP_SUPPORTS_FP16 +#endif + +//TODO maybe tune this number, it varies by GPU +static const int targetNumThreads = 512; + +void splitThreadsAcrossDim01(int dim0Size, int dim1Size, int& threads0, int& blocks0, int& threads1, int& blocks1) { + if(dim0Size > targetNumThreads) { + threads0 = targetNumThreads/2; + blocks0 = (dim0Size + threads0 - 1) / threads0; + threads1 = 1; + blocks1 = dim1Size; + } + else if(dim0Size > targetNumThreads/2) { + threads0 = dim0Size; + blocks0 = 1; + threads1 = 1; + blocks1 = dim1Size; + } + else { + threads0 = dim0Size; + blocks0 = 1; + threads1 = targetNumThreads / dim0Size; + blocks1 = (dim1Size + threads1 - 1) / threads1; + } +} + +__forceinline__ __device__ float mishf(float a) { + return a * tanhf(a < 20.0f ? log1pf(expf(a)) : a); +} +__forceinline__ __device__ float mishf_scale8(float a) { + return a < 2.5f ? a * tanhf(log1pf(expf(a*8.0f))) : a; +} + +#ifdef HIP_SUPPORTS_FP16 +__forceinline__ __device__ half mishh(half h) { + float a = __half2float(h); + return __float2half(a * tanhf(a < 20.0f ? log1pf(expf(a)) : a)); +} +__forceinline__ __device__ half mishh_scale8(half h) { + float a = __half2float(h); + return __float2half(a < 2.5f ? a * tanhf(log1pf(expf(a*8.0f))) : a); +} +#endif + +//-------------------------------------------------------------------------------------------------------------- + +template +__global__ +void channelConcatKernel( + const T* inA, + const T* inB, + T* out, + int chwA, + int chwB, + int numBlocksA, + int numBlocksB, + int n +) { + if(blockIdx.x < numBlocksA) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if(index < chwA) { + int nchwA = n*chwA; + int chwOut = (chwA+chwB); + + int aIdx = index; + int outIdx = index; + while(aIdx < nchwA) { + out[outIdx] = inA[aIdx]; + aIdx += chwA; + outIdx += chwOut; + } + } + } + else { + int index = (blockIdx.x - numBlocksA) * blockDim.x + threadIdx.x; + if(index < chwB) { + int nchwB = n*chwB; + int chwOut = (chwA+chwB); + + int bIdx = index; + int outIdx = chwA+index; + while(bIdx < nchwB) { + out[outIdx] = inB[bIdx]; + bIdx += chwB; + outIdx += chwOut; + } + } + } +} + +template +void customCudaChannelConcatTemplate(const T* inA, const T* inB, T* out, int chwA, int chwB, int n) { + int blockSize = targetNumThreads; + int numBlocksA = (chwA + blockSize-1) / blockSize; + int numBlocksB = (chwB + blockSize-1) / blockSize; + int numBlocks = numBlocksA + numBlocksB; + channelConcatKernel<<>>(inA,inB,out,chwA,chwB,numBlocksA,numBlocksB,n); +} +template void customCudaChannelConcatTemplate(const float* inA, const float* inB, float* out, int chwA, int chwB, int n); +template void customCudaChannelConcatTemplate(const half* inA, const half* inB, half* out, int chwA, int chwB, int n); + +void customCudaChannelConcat(const float* inA, const float* inB, float* out, int chwA, int chwB, int n) { + customCudaChannelConcatTemplate(inA,inB,out,chwA,chwB,n); +} +void customCudaChannelConcat(const half* inA, const half* inB, half* out, int chwA, int chwB, int n) { + customCudaChannelConcatTemplate(inA,inB,out,chwA,chwB,n); +} + +//-------------------------------------------------------------------------------------------------------------- + +template +__global__ +void extractChannel0KernelNHWC(const T *in, T* out, int nhwSize, int cSize) +{ + int nhwIdx = blockIdx.x * blockDim.x + threadIdx.x; + if(nhwIdx < nhwSize) { + out[nhwIdx] = in[nhwIdx*cSize]; + } +} +template +void customCudaChannel0ExtractNHWCTemplate(const T *in, T* out, int n, int hw, int c) { + int nhw = n*hw; + int blockSize = targetNumThreads; + int numBlocks = (nhw+blockSize-1)/blockSize; + extractChannel0KernelNHWC<<>>(in,out,nhw,c); +} + +template +__global__ +void extractChannel0KernelNCHW(const T *in, T* out, int nSize, int cSize, int hwSize) +{ + int hwIdx = blockIdx.x * blockDim.x + threadIdx.x; + int nIdx = blockIdx.y * blockDim.y + threadIdx.y; + if(hwIdx < hwSize && nIdx < nSize) { + out[nIdx * hwSize + hwIdx] = in[nIdx * cSize * hwSize + hwIdx]; + } +} +template +void customCudaChannel0ExtractNCHWTemplate(const T *in, T* out, int nSize, int cSize, int hwSize) { + int hwThreads; + int hwBlocks; + int nThreads; + int nBlocks; + splitThreadsAcrossDim01(hwSize, nSize, hwThreads, hwBlocks, nThreads, nBlocks); + + if(nBlocks > 65536) + throw std::runtime_error("customCudaChannel0ExtractNCHW: nSize too large given hwSize"); + + dim3 grid(hwBlocks,nBlocks,1); + dim3 threads(hwThreads,nThreads,1); + extractChannel0KernelNCHW<<>>(in,out,nSize,cSize,hwSize); +} + +void customCudaChannel0ExtractNCHW(const float* in, float* out, int n, int c, int hw) { + customCudaChannel0ExtractNCHWTemplate(in,out,n,c,hw); +} +void customCudaChannel0ExtractNCHW(const half* in, half* out, int n, int c, int hw) { + customCudaChannel0ExtractNCHWTemplate(in,out,n,c,hw); +} +void customCudaChannel0ExtractNHWC(const float* in, float* out, int n, int hw, int c) { + customCudaChannel0ExtractNHWCTemplate(in,out,n,hw,c); +} +void customCudaChannel0ExtractNHWC(const half* in, half* out, int n, int hw, int c) { + customCudaChannel0ExtractNHWCTemplate(in,out,n,hw,c); +} + +//-------------------------------------------------------------------------------------------------------------- + +__global__ +void sumChannelsNCHWKernel(const float* in, float* out, int cSize, int xySize, float scaleSum) +{ + extern __shared__ float sumPoolNCHWShared[]; + int xyId = threadIdx.x; + int xyBlockDim = blockDim.x; + int cId = threadIdx.y; + int cBlockDim = blockDim.y; + int cIdx = blockIdx.y * cBlockDim + cId; + int nIdx = blockIdx.z; + + int xycSize = xySize*cSize; + int sharedIdx = xyId + cId * xyBlockDim; + + float acc = 0.0f; + if(cIdx < cSize) { + int xyIdx = xyId; + while(xyIdx < xySize) { + acc += in[xyIdx + cIdx * xySize + nIdx * xycSize]; + xyIdx += xyBlockDim; + } + sumPoolNCHWShared[sharedIdx] = acc; + } + __syncthreads(); + + for(int s = xyBlockDim>>1; s > 0; s >>= 1) { + if(xyId < s) { + sumPoolNCHWShared[sharedIdx] += sumPoolNCHWShared[sharedIdx + s]; + } + __syncthreads(); + } + if(xyId == 0 && cIdx < cSize) + out[cIdx + nIdx * cSize] = sumPoolNCHWShared[sharedIdx] * scaleSum; +} +__global__ +void valueHeadPoolChannelsNCHWKernel(const float* in, float* out, int nSize, int cSize, int xySize, const float* maskSum) +{ + extern __shared__ float sumPoolNCHWShared[]; + int xyId = threadIdx.x; + int xyBlockDim = blockDim.x; + int cId = threadIdx.y; + int cBlockDim = blockDim.y; + int cIdx = blockIdx.y * cBlockDim + cId; + int nIdx = blockIdx.z; + + int xycSize = xySize*cSize; + int sharedIdx = xyId + cId * xyBlockDim; + + float acc = 0.0f; + if(cIdx < cSize) { + int xyIdx = xyId; + while(xyIdx < xySize) { + acc += in[xyIdx + cIdx * xySize + nIdx * xycSize]; + xyIdx += xyBlockDim; + } + sumPoolNCHWShared[sharedIdx] = acc; + } + __syncthreads(); + + for(int s = xyBlockDim>>1; s > 0; s >>= 1) { + if(xyId < s) { + sumPoolNCHWShared[sharedIdx] += sumPoolNCHWShared[sharedIdx + s]; + } + __syncthreads(); + } + if(xyId == 0 && cIdx < cSize) { + float sum = sumPoolNCHWShared[sharedIdx]; + float div = maskSum[nIdx]; + float sqrtdiv = sqrt(div); + float mean = sum/div; + out[cIdx + nIdx * cSize*3] = mean; + out[cIdx + nIdx * cSize*3 + cSize] = mean * (sqrtdiv - 14.0f) * 0.1f; + out[cIdx + nIdx * cSize*3 + cSize*2] = mean * ((sqrtdiv - 14.0f) * (sqrtdiv - 14.0f) * 0.01f - 0.1f); + } +} +__global__ +void gPoolChannelsNCHWKernel(const float* in, float* out, int cSize, int xySize, const float* maskSum, int sharedMemElts) +{ + extern __shared__ float poolNCHWShared[]; + float* sumShared = (float*)poolNCHWShared; + float* maxShared = (float*)poolNCHWShared + sharedMemElts; + + int xyId = threadIdx.x; + int xyBlockDim = blockDim.x; + int cId = threadIdx.y; + int cBlockDim = blockDim.y; + int cIdx = blockIdx.y * cBlockDim + cId; + int nIdx = blockIdx.z; + + int xycSize = xySize*cSize; + int sharedIdx = xyId + cId * xyBlockDim; + + if(cIdx < cSize) { + float accSum = 0.0f; + float accMax = -1.0f; + int xyIdx = xyId; + while(xyIdx < xySize) { + float a = in[xyIdx + cIdx * xySize + nIdx * xycSize]; + accSum += a; + accMax = fmaxf(accMax, a); + xyIdx += xyBlockDim; + } + sumShared[sharedIdx] = accSum; + maxShared[sharedIdx] = accMax; + } + __syncthreads(); + + for(int s = xyBlockDim>>1; s > 0; s >>= 1) { + if(xyId < s) { + sumShared[sharedIdx] += sumShared[sharedIdx + s]; + maxShared[sharedIdx] = fmaxf(maxShared[sharedIdx], maxShared[sharedIdx + s]); + } + __syncthreads(); + } + if(xyId == 0 && cIdx < cSize) { + float sum = sumShared[sharedIdx]; + float div = maskSum[nIdx]; + float sqrtdiv = sqrt(div); + float mean = sum/div; + + out[cIdx + nIdx * (cSize*3)] = mean; + out[cIdx + nIdx * (cSize*3) + cSize] = mean * (sqrtdiv - 14.0f) * 0.1f; + out[cIdx + nIdx * (cSize*3) + cSize*2] = maxShared[sharedIdx]; + } +} +__global__ +void gPoolChannelsNCHWMaskKernel(const float* in, float* out, int cSize, int xySize, const float* mask, const float* maskSum, int sharedMemElts) +{ + extern __shared__ float poolNCHWShared[]; + float* sumShared = (float*)poolNCHWShared; + float* maxShared = (float*)poolNCHWShared + sharedMemElts; + + int xyId = threadIdx.x; + int xyBlockDim = blockDim.x; + int cId = threadIdx.y; + int cBlockDim = blockDim.y; + int cIdx = blockIdx.y * cBlockDim + cId; + int nIdx = blockIdx.z; + + int xycSize = xySize*cSize; + int sharedIdx = xyId + cId * xyBlockDim; + + if(cIdx < cSize) { + float accSum = 0.0f; + float accMax = -1.0f; + int xyIdx = xyId; + while(xyIdx < xySize) { + float a = in[xyIdx + cIdx * xySize + nIdx * xycSize]; + accSum += a; + // Init to -1.0 above and + mask - 1.0 is because it will effectively make all padded space into -1.0 + // which is lower than the lowest value that any current activation function will produce. + // so the max over all valid spaces will the same as the mask over all spaces including padding + // We're relying on all padded space being equal to 0 because this gpool only ever follows a BN+Activate with a mask. + accMax = fmaxf(accMax, a + (mask[xyIdx + nIdx * xySize] - 1.0f)); + xyIdx += xyBlockDim; + } + sumShared[sharedIdx] = accSum; + maxShared[sharedIdx] = accMax; + } + __syncthreads(); + + for(int s = xyBlockDim>>1; s > 0; s >>= 1) { + if(xyId < s) { + sumShared[sharedIdx] += sumShared[sharedIdx + s]; + maxShared[sharedIdx] = fmaxf(maxShared[sharedIdx], maxShared[sharedIdx + s]); + } + __syncthreads(); + } + if(xyId == 0 && cIdx < cSize) { + float sum = sumShared[sharedIdx]; + float div = maskSum[nIdx]; + float sqrtdiv = sqrt(div); + float mean = sum/div; + + out[cIdx + nIdx * (cSize*3)] = mean; + out[cIdx + nIdx * (cSize*3) + cSize] = mean * (sqrtdiv - 14.0f) * 0.1f; + out[cIdx + nIdx * (cSize*3) + cSize*2] = maxShared[sharedIdx]; + } +} + +void customCudaPoolRowsSumNCHW(const float* in, float* out, int nSize, int cSize, int xySize, float scaleSum) { + if(nSize > 65536) + throw std::runtime_error("customCudaPoolRowsSumNCHW: nSize too large"); + if(cSize > 65536) + throw std::runtime_error("customCudaPoolRowsSumNCHW: cSize too large"); + + //Use up as many threads as possible along the xy dimension. + int xyThreads = 1; + while(xyThreads < targetNumThreads && xyThreads < xySize/2) + xyThreads *= 2; + + //Distribute the extra threads along the c dimension. + int cThreads = (targetNumThreads < xyThreads) ? 1 : (targetNumThreads / xyThreads); + int cBlocks = (cSize + cThreads - 1) / cThreads; + + //We need one shared memory spot per thread + int sharedMemSize = sizeof(float) * cThreads * xyThreads; + + dim3 grid(1,cBlocks,nSize); + dim3 threads(xyThreads,cThreads,1); + sumChannelsNCHWKernel<<>>(in,out,cSize,xySize,scaleSum); +} +void customCudaValueHeadPoolNCHW(const float* in, float* out, int nSize, int cSize, int xySize, const float* maskSum) { + if(nSize > 65536) + throw std::runtime_error("customCudaValueHeadPoolNCHW: nSize too large"); + if(cSize > 65536) + throw std::runtime_error("customCudaValueHeadPoolNCHW: cSize too large"); + + //Use up as many threads as possible along the xy dimension. + int xyThreads = 1; + while(xyThreads < targetNumThreads && xyThreads < xySize/2) + xyThreads *= 2; + + //Distribute the extra threads along the c dimension. + int cThreads = (targetNumThreads < xyThreads) ? 1 : (targetNumThreads / xyThreads); + int cBlocks = (cSize + cThreads - 1) / cThreads; + + //We need one shared memory spot per thread + int sharedMemSize = sizeof(float) * cThreads * xyThreads; + + dim3 grid(1,cBlocks,nSize); + dim3 threads(xyThreads,cThreads,1); + valueHeadPoolChannelsNCHWKernel<<>>(in,out,nSize,cSize,xySize,maskSum); +} +void customCudaPoolRowsGPoolNCHW(const float* in, float* out, int nSize, int cSize, int xySize, const float* mask, const float* maskSum) { + if(nSize > 65536) + throw std::runtime_error("customCudaPoolRowsGPoolNCHW: nSize too large"); + if(cSize > 65536) + throw std::runtime_error("customCudaPoolRowsGPoolNCHW: cSize too large"); + + //Use up as many threads as possible along the xy dimension. + int xyThreads = 1; + while(xyThreads < targetNumThreads && xyThreads < xySize/2) + xyThreads *= 2; + + //Distribute the extra threads along the c dimension. + int cThreads = (targetNumThreads < xyThreads) ? 1 : (targetNumThreads / xyThreads); + int cBlocks = (cSize + cThreads - 1) / cThreads; + + //We need one shared memory spot per thread, and then we double it because we need both sum and max. + //We also make sure it's a power of two to address any alignment concerns. + int sharedMemElts = 128; + while(sharedMemElts < cThreads * xyThreads) + sharedMemElts *= 2; + int sharedMemSize = sizeof(float) * sharedMemElts * 2; + + dim3 grid(1,cBlocks,nSize); + dim3 threads(xyThreads,cThreads,1); + if(mask != NULL) + gPoolChannelsNCHWMaskKernel<<>>(in,out,cSize,xySize,mask,maskSum,sharedMemElts); + else + gPoolChannelsNCHWKernel<<>>(in,out,cSize,xySize,maskSum,sharedMemElts); +} + +//-------------------------------------------------------------------------------------------------------------- + +__global__ +void gPoolChannelsNCHWHalfKernel(const half* in, half* out, int cSize, int xySize, const float* maskSum, int sharedMemElts) +{ +#ifdef HIP_SUPPORTS_FP16 + extern __shared__ float poolNCHWShared[]; + float* sumShared = (float*)poolNCHWShared; + float* maxShared = (float*)poolNCHWShared + sharedMemElts; + + int xyId = threadIdx.x; + int xyBlockDim = blockDim.x; + int cId = threadIdx.y; + int cBlockDim = blockDim.y; + int cIdx = blockIdx.y * cBlockDim + cId; + int nIdx = blockIdx.z; + + int xycSize = xySize*cSize; + int sharedIdx = xyId + cId * xyBlockDim; + + if(cIdx < cSize) { + float accSum = 0.0f; + float accMax = -1.0f; + int xyIdx = xyId; + while(xyIdx < xySize) { + float a = __half2float(in[xyIdx + cIdx * xySize + nIdx * xycSize]); + accSum += a; + accMax = fmaxf(accMax, a); + xyIdx += xyBlockDim; + } + sumShared[sharedIdx] = accSum; + maxShared[sharedIdx] = accMax; + } + __syncthreads(); + + for(int s = xyBlockDim>>1; s > 0; s >>= 1) { + if(xyId < s) { + sumShared[sharedIdx] += sumShared[sharedIdx + s]; + maxShared[sharedIdx] = fmaxf(maxShared[sharedIdx], maxShared[sharedIdx + s]); + } + __syncthreads(); + } + if(xyId == 0 && cIdx < cSize) { + float sum = sumShared[sharedIdx]; + float div = maskSum[nIdx]; + float sqrtdiv = sqrt(div); + float mean = sum/div; + + out[cIdx + nIdx * (cSize*3)] = __float2half(mean); + out[cIdx + nIdx * (cSize*3) + cSize] = __float2half(mean * (sqrtdiv - 14.0f) * 0.1f); + out[cIdx + nIdx * (cSize*3) + cSize*2] = __float2half(maxShared[sharedIdx]); + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void gPoolChannelsNCHWHalfMaskKernel(const half* in, half* out, int cSize, int xySize, const half* mask, const float* maskSum, int sharedMemElts) +{ +#ifdef HIP_SUPPORTS_FP16 + extern __shared__ float poolNCHWShared[]; + float* sumShared = (float*)poolNCHWShared; + float* maxShared = (float*)poolNCHWShared + sharedMemElts; + + int xyId = threadIdx.x; + int xyBlockDim = blockDim.x; + int cId = threadIdx.y; + int cBlockDim = blockDim.y; + int cIdx = blockIdx.y * cBlockDim + cId; + int nIdx = blockIdx.z; + + int xycSize = xySize*cSize; + int sharedIdx = xyId + cId * xyBlockDim; + + if(cIdx < cSize) { + float accSum = 0.0f; + float accMax = -1.0f; + int xyIdx = xyId; + while(xyIdx < xySize) { + float a = __half2float(in[xyIdx + cIdx * xySize + nIdx * xycSize]); + accSum += a; + // Init to -1.0 above and + mask - 1.0 is because it will effectively make all padded space into -1.0 + // which is lower than the lowest value that any current activation function will produce. + // so the max over all valid spaces will the same as the mask over all spaces including padding + accMax = fmaxf(accMax, a + (__half2float(mask[xyIdx + nIdx * xySize]) - 1.0f)); + xyIdx += xyBlockDim; + } + sumShared[sharedIdx] = accSum; + maxShared[sharedIdx] = accMax; + } + __syncthreads(); + + for(int s = xyBlockDim>>1; s > 0; s >>= 1) { + if(xyId < s) { + sumShared[sharedIdx] += sumShared[sharedIdx + s]; + maxShared[sharedIdx] = fmaxf(maxShared[sharedIdx], maxShared[sharedIdx + s]); + } + __syncthreads(); + } + if(xyId == 0 && cIdx < cSize) { + float sum = sumShared[sharedIdx]; + float div = maskSum[nIdx]; + float sqrtdiv = sqrt(div); + float mean = sum/div; + + out[cIdx + nIdx * (cSize*3)] = __float2half(mean); + out[cIdx + nIdx * (cSize*3) + cSize] = __float2half(mean * (sqrtdiv - 14.0f) * 0.1f); + out[cIdx + nIdx * (cSize*3) + cSize*2] = __float2half(maxShared[sharedIdx]); + } +#else + //Do nothing, FP16 not supported +#endif +} + +void customCudaPoolRowsGPoolNCHW(const half* in, half* out, int nSize, int cSize, int xySize, const half* mask, const float* maskSum) { + if(nSize > 65536) + throw std::runtime_error("customCudaPoolRowsGPoolNCHW: nSize too large"); + if(cSize > 65536) + throw std::runtime_error("customCudaPoolRowsGPoolNCHW: cSize too large"); + + //Use up as many threads as possible along the xy dimension. + int xyThreads = 1; + while(xyThreads < targetNumThreads && xyThreads < xySize/2) + xyThreads *= 2; + + //Distribute the extra threads along the c dimension. + int cThreads = (targetNumThreads < xyThreads) ? 1 : (targetNumThreads / xyThreads); + int cBlocks = (cSize + cThreads - 1) / cThreads; + + //We need one shared memory spot per thread, and then we double it because we need both sum and max. + //We also make sure it's a power of two to address any alignment concerns. + int sharedMemElts = 128; + while(sharedMemElts < cThreads * xyThreads) + sharedMemElts *= 2; + int sharedMemSize = sizeof(float) * sharedMemElts * 2; + + dim3 grid(1,cBlocks,nSize); + dim3 threads(xyThreads,cThreads,1); + if(mask != NULL) + gPoolChannelsNCHWHalfMaskKernel<<>>(in,out,cSize,xySize,mask,maskSum,sharedMemElts); + else + gPoolChannelsNCHWHalfKernel<<>>(in,out,cSize,xySize,maskSum,sharedMemElts); +} + + + +//-------------------------------------------------------------------------------------------------------------- + +__global__ +void sumChannelsNHWCKernel(const float* in, float* out, int xySize, int cSize, float scaleSum) +{ + extern __shared__ float sumPoolNHWCShared[]; + int cId = threadIdx.x; + int cBlockDim = blockDim.x; + int xyId = threadIdx.y; + int xyBlockDim = blockDim.y; + + int cIdx = blockIdx.x * cBlockDim + cId; + int nIdx = blockIdx.z; + int sharedIdx = cId + cBlockDim * xyId; + int xycSize = xySize*cSize; + + sumPoolNHWCShared[sharedIdx] = 0; + + if(cIdx < cSize) { + int xyIdx = xyId; + while(xyIdx < xySize) { + sumPoolNHWCShared[sharedIdx] += in[cIdx + xyIdx * cSize + nIdx * xycSize]; + xyIdx += xyBlockDim; + } + } + __syncthreads(); + + for(int s = xyBlockDim>>1; s > 0; s >>= 1) { + if(xyId < s) { + sumPoolNHWCShared[sharedIdx] += sumPoolNHWCShared[sharedIdx + cBlockDim * s]; + } + __syncthreads(); + } + if(xyId == 0 && cIdx < cSize) + out[cIdx + nIdx * cSize] = sumPoolNHWCShared[sharedIdx] * scaleSum; +} +__global__ +void valueHeadPoolChannelsNHWCKernel(const float* in, float* out, int nSize, int xySize, int cSize, const float* maskSum) +{ + extern __shared__ float sumPoolNHWCShared[]; + int cId = threadIdx.x; + int cBlockDim = blockDim.x; + int xyId = threadIdx.y; + int xyBlockDim = blockDim.y; + + int cIdx = blockIdx.x * cBlockDim + cId; + int nIdx = blockIdx.z; + int sharedIdx = cId + cBlockDim * xyId; + int xycSize = xySize*cSize; + + sumPoolNHWCShared[sharedIdx] = 0; + + if(cIdx < cSize) { + int xyIdx = xyId; + while(xyIdx < xySize) { + sumPoolNHWCShared[sharedIdx] += in[cIdx + xyIdx * cSize + nIdx * xycSize]; + xyIdx += xyBlockDim; + } + } + __syncthreads(); + + for(int s = xyBlockDim>>1; s > 0; s >>= 1) { + if(xyId < s) { + sumPoolNHWCShared[sharedIdx] += sumPoolNHWCShared[sharedIdx + cBlockDim * s]; + } + __syncthreads(); + } + if(xyId == 0 && cIdx < cSize) { + float sum = sumPoolNHWCShared[sharedIdx]; + float div = maskSum[nIdx]; + float sqrtdiv = sqrt(div); + float mean = sum/div; + out[cIdx + nIdx * cSize*3] = mean; + out[cIdx + nIdx * cSize*3 + cSize] = mean * (sqrtdiv - 14.0f) * 0.1f; + out[cIdx + nIdx * cSize*3 + cSize*2] = mean * ((sqrtdiv - 14.0f) * (sqrtdiv - 14.0f) * 0.01f - 0.1f); + } +} +__global__ +void gPoolChannelsNHWCKernel(const float* in, float* out, int xySize, int cSize, const float* maskSum, int sharedMemElts) +{ + extern __shared__ float poolNHWCShared[]; + float* sumShared = (float*)poolNHWCShared; + float* maxShared = (float*)poolNHWCShared + sharedMemElts; + + int cId = threadIdx.x; + int cBlockDim = blockDim.x; + int xyId = threadIdx.y; + int xyBlockDim = blockDim.y; + + int cIdx = blockIdx.x * cBlockDim + cId; + int nIdx = blockIdx.z; + int sharedIdx = cId + cBlockDim * xyId; + int xycSize = xySize*cSize; + + sumShared[sharedIdx] = 0; + maxShared[sharedIdx] = -1.0f; + + if(cIdx < cSize) { + int xyIdx = xyId; + while(xyIdx < xySize) { + float a = in[cIdx + xyIdx * cSize + nIdx * xycSize]; + sumShared[sharedIdx] += a; + maxShared[sharedIdx] = fmaxf(maxShared[sharedIdx], a); + xyIdx += xyBlockDim; + } + } + __syncthreads(); + + for(int s = xyBlockDim>>1; s > 0; s >>= 1) { + if(xyId < s) { + sumShared[sharedIdx] += sumShared[sharedIdx + cBlockDim * s]; + maxShared[sharedIdx] = fmaxf(maxShared[sharedIdx],maxShared[sharedIdx + cBlockDim * s]); + } + __syncthreads(); + } + if(xyId == 0 && cIdx < cSize) { + float sum = sumShared[sharedIdx]; + float div = maskSum[nIdx]; + float sqrtdiv = sqrt(div); + float mean = sum/div; + + out[cIdx + nIdx * (cSize*3)] = mean; + out[cIdx + nIdx * (cSize*3) + cSize] = mean * (sqrtdiv - 14.0f) * 0.1f; + out[cIdx + nIdx * (cSize*3) + cSize*2] = maxShared[sharedIdx]; + } +} +__global__ +void gPoolChannelsNHWCMaskKernel(const float* in, float* out, int xySize, int cSize, const float* mask, const float* maskSum, int sharedMemElts) +{ + extern __shared__ float poolNHWCShared[]; + float* sumShared = (float*)poolNHWCShared; + float* maxShared = (float*)poolNHWCShared + sharedMemElts; + + int cId = threadIdx.x; + int cBlockDim = blockDim.x; + int xyId = threadIdx.y; + int xyBlockDim = blockDim.y; + + int cIdx = blockIdx.x * cBlockDim + cId; + int nIdx = blockIdx.z; + int sharedIdx = cId + cBlockDim * xyId; + int xycSize = xySize*cSize; + + sumShared[sharedIdx] = 0; + maxShared[sharedIdx] = -1.0f; + + if(cIdx < cSize) { + int xyIdx = xyId; + while(xyIdx < xySize) { + float a = in[cIdx + xyIdx * cSize + nIdx * xycSize]; + sumShared[sharedIdx] += a; + // Init to -1.0 above and + mask - 1.0 is because it will effectively make all padded space into -1.0 + // which is lower than the lowest value that any current activation function will produce. + // so the max over all valid spaces will the same as the mask over all spaces including padding + maxShared[sharedIdx] = fmaxf(maxShared[sharedIdx], a + (mask[xyIdx + nIdx * xySize] - 1.0f)); + xyIdx += xyBlockDim; + } + } + __syncthreads(); + + for(int s = xyBlockDim>>1; s > 0; s >>= 1) { + if(xyId < s) { + sumShared[sharedIdx] += sumShared[sharedIdx + cBlockDim * s]; + maxShared[sharedIdx] = fmaxf(maxShared[sharedIdx],maxShared[sharedIdx + cBlockDim * s]); + } + __syncthreads(); + } + if(xyId == 0 && cIdx < cSize) { + float sum = sumShared[sharedIdx]; + float div = maskSum[nIdx]; + float sqrtdiv = sqrt(div); + float mean = sum/div; + + out[cIdx + nIdx * (cSize*3)] = mean; + out[cIdx + nIdx * (cSize*3) + cSize] = mean * (sqrtdiv - 14.0f) * 0.1f; + out[cIdx + nIdx * (cSize*3) + cSize*2] = maxShared[sharedIdx]; + } +} + + +void customCudaPoolRowsSumNHWC(const float* in, float* out, int nSize, int xySize, int cSize, float scaleSum) { + if(nSize > 65536) + throw std::runtime_error("customCudaPoolRowsSumNHWC: nSize too large"); + if(cSize > 65536) + throw std::runtime_error("customCudaPoolRowsSumNHWC: cSize too large"); + + //Use up to two warps worth of threads along the channel dimension, which is the + //most compact + int cThreads = 1; + while(cThreads < 64 && cThreads < cSize/2) + cThreads *= 2; + int cBlocks = (cSize + cThreads - 1) / cThreads; + + //Distribute the extra threads to perform parallel reduction along the xy dimension. + int xyThreads = (targetNumThreads < cThreads) ? 1 : (targetNumThreads / cThreads); + + //We need one shared memory spot per thread + int sharedMemSize = sizeof(float) * cThreads * xyThreads; + + dim3 grid(cBlocks,1,nSize); + dim3 threads(cThreads,xyThreads,1); + sumChannelsNHWCKernel<<>>(in,out,xySize,cSize,scaleSum); +} + +void customCudaValueHeadPoolNHWC(const float* in, float* out, int nSize, int xySize, int cSize, const float* maskSum) { + if(nSize > 65536) + throw std::runtime_error("customCudaValueHeadPoolNHWC: nSize too large"); + if(cSize > 65536) + throw std::runtime_error("customCudaValueHeadPoolNHWC: cSize too large"); + + //Use up to two warps worth of threads along the channel dimension, which is the + //most compact + int cThreads = 1; + while(cThreads < 64 && cThreads < cSize/2) + cThreads *= 2; + int cBlocks = (cSize + cThreads - 1) / cThreads; + + //Distribute the extra threads to perform parallel reduction along the xy dimension. + int xyThreads = (targetNumThreads < cThreads) ? 1 : (targetNumThreads / cThreads); + + //We need one shared memory spot per thread + int sharedMemSize = sizeof(float) * cThreads * xyThreads; + + dim3 grid(cBlocks,1,nSize); + dim3 threads(cThreads,xyThreads,1); + valueHeadPoolChannelsNHWCKernel<<>>(in,out,nSize,xySize,cSize,maskSum); +} + +void customCudaPoolRowsGPoolNHWC(const float* in, float* out, int nSize, int xySize, int cSize, const float* mask, const float* maskSum) { + if(nSize > 65536) + throw std::runtime_error("customCudaPoolRowsGPoolNHWC: nSize too large"); + if(cSize > 65536) + throw std::runtime_error("customCudaPoolRowsGPoolNHWC: cSize too large"); + + //Use up to two warps worth of threads along the channel dimension, which is the + //most compact + int cThreads = 1; + while(cThreads < 64 && cThreads < cSize/2) + cThreads *= 2; + int cBlocks = (cSize + cThreads - 1) / cThreads; + + //Distribute the extra threads to perform parallel reduction along the xy dimension. + int xyThreads = (targetNumThreads < cThreads) ? 1 : (targetNumThreads / cThreads); + + //We need one shared memory spot per thread, and then we double it because we need both sum and max. + //We also make sure it's a power of two to address any alignment concerns. + int sharedMemElts = 128; + while(sharedMemElts < cThreads * xyThreads) + sharedMemElts *= 2; + int sharedMemSize = sizeof(float) * sharedMemElts * 2; + + dim3 grid(cBlocks,1,nSize); + dim3 threads(cThreads,xyThreads,1); + if(mask != NULL) + gPoolChannelsNHWCMaskKernel<<>>(in,out,xySize,cSize,mask,maskSum,sharedMemElts); + else + gPoolChannelsNHWCKernel<<>>(in,out,xySize,cSize,maskSum,sharedMemElts); +} + +//-------------------------------------------------------------------------------------------------------------- + +__global__ +void gPoolChannelsNHWCHalfKernel(const half* in, half* out, int xySize, int cSize, const float* maskSum, int sharedMemElts) +{ +#ifdef HIP_SUPPORTS_FP16 + extern __shared__ float poolNHWCShared[]; + float* sumShared = (float*)poolNHWCShared; + float* maxShared = (float*)poolNHWCShared + sharedMemElts; + + int cId = threadIdx.x; + int cBlockDim = blockDim.x; + int xyId = threadIdx.y; + int xyBlockDim = blockDim.y; + + int cIdx = blockIdx.x * cBlockDim + cId; + int nIdx = blockIdx.z; + int sharedIdx = cId + cBlockDim * xyId; + int xycSize = xySize*cSize; + + sumShared[sharedIdx] = 0; + maxShared[sharedIdx] = -1.0f; + + if(cIdx < cSize) { + int xyIdx = xyId; + while(xyIdx < xySize) { + float a = __half2float(in[cIdx + xyIdx * cSize + nIdx * xycSize]); + sumShared[sharedIdx] += a; + maxShared[sharedIdx] = fmaxf(maxShared[sharedIdx], a); + xyIdx += xyBlockDim; + } + } + __syncthreads(); + + for(int s = xyBlockDim>>1; s > 0; s >>= 1) { + if(xyId < s) { + sumShared[sharedIdx] += sumShared[sharedIdx + cBlockDim * s]; + maxShared[sharedIdx] = fmaxf(maxShared[sharedIdx],maxShared[sharedIdx + cBlockDim * s]); + } + __syncthreads(); + } + if(xyId == 0 && cIdx < cSize) { + float sum = sumShared[sharedIdx]; + float div = maskSum[nIdx]; + float sqrtdiv = sqrt(div); + float mean = sum/div; + + out[cIdx + nIdx * (cSize*3)] = __float2half(mean); + out[cIdx + nIdx * (cSize*3) + cSize] = __float2half(mean * (sqrtdiv - 14.0f) * 0.1f); + out[cIdx + nIdx * (cSize*3) + cSize*2] = __float2half(maxShared[sharedIdx]); + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void gPoolChannelsNHWCHalfMaskKernel(const half* in, half* out, int xySize, int cSize, const half* mask, const float* maskSum, int sharedMemElts) +{ +#ifdef HIP_SUPPORTS_FP16 + extern __shared__ float poolNHWCShared[]; + float* sumShared = (float*)poolNHWCShared; + float* maxShared = (float*)poolNHWCShared + sharedMemElts; + + int cId = threadIdx.x; + int cBlockDim = blockDim.x; + int xyId = threadIdx.y; + int xyBlockDim = blockDim.y; + + int cIdx = blockIdx.x * cBlockDim + cId; + int nIdx = blockIdx.z; + int sharedIdx = cId + cBlockDim * xyId; + int xycSize = xySize*cSize; + + sumShared[sharedIdx] = 0; + maxShared[sharedIdx] = -1.0f; + + if(cIdx < cSize) { + int xyIdx = xyId; + while(xyIdx < xySize) { + float a = __half2float(in[cIdx + xyIdx * cSize + nIdx * xycSize]); + sumShared[sharedIdx] += a; + // Init to -1.0 above and + mask - 1.0 is because it will effectively make all padded space into -1.0 + // which is lower than the lowest value that any current activation function will produce. + // so the max over all valid spaces will the same as the mask over all spaces including padding + maxShared[sharedIdx] = fmaxf(maxShared[sharedIdx], a + (__half2float(mask[xyIdx + nIdx * xySize]) - 1.0f)); + xyIdx += xyBlockDim; + } + } + __syncthreads(); + + for(int s = xyBlockDim>>1; s > 0; s >>= 1) { + if(xyId < s) { + sumShared[sharedIdx] += sumShared[sharedIdx + cBlockDim * s]; + maxShared[sharedIdx] = fmaxf(maxShared[sharedIdx],maxShared[sharedIdx + cBlockDim * s]); + } + __syncthreads(); + } + if(xyId == 0 && cIdx < cSize) { + float sum = sumShared[sharedIdx]; + float div = maskSum[nIdx]; + float sqrtdiv = sqrt(div); + float mean = sum/div; + + out[cIdx + nIdx * (cSize*3)] = __float2half(mean); + out[cIdx + nIdx * (cSize*3) + cSize] = __float2half(mean * (sqrtdiv - 14.0f) * 0.1f); + out[cIdx + nIdx * (cSize*3) + cSize*2] = __float2half(maxShared[sharedIdx]); + } +#else + //Do nothing, FP16 not supported +#endif +} + +void customCudaPoolRowsGPoolNHWC(const half* in, half* out, int nSize, int xySize, int cSize, const half* mask, const float* maskSum) { + if(nSize > 65536) + throw std::runtime_error("customCudaPoolRowsGPoolNHWC: nSize too large"); + if(cSize > 65536) + throw std::runtime_error("customCudaPoolRowsGPoolNHWC: cSize too large"); + + //Use up to two warps worth of threads along the channel dimension, which is the + //most compact + int cThreads = 1; + while(cThreads < 64 && cThreads < cSize/2) + cThreads *= 2; + int cBlocks = (cSize + cThreads - 1) / cThreads; + + //Distribute the extra threads to perform parallel reduction along the xy dimension. + int xyThreads = (targetNumThreads < cThreads) ? 1 : (targetNumThreads / cThreads); + + //We need one shared memory spot per thread, and then we double it because we need both sum and max. + //We also make sure it's a power of two to address any alignment concerns. + int sharedMemElts = 128; + while(sharedMemElts < cThreads * xyThreads) + sharedMemElts *= 2; + int sharedMemSize = sizeof(float) * sharedMemElts * 2; + + dim3 grid(cBlocks,1,nSize); + dim3 threads(cThreads,xyThreads,1); + if(mask != NULL) + gPoolChannelsNHWCHalfMaskKernel<<>>(in,out,xySize,cSize,mask,maskSum,sharedMemElts); + else + gPoolChannelsNHWCHalfKernel<<>>(in,out,xySize,cSize,maskSum,sharedMemElts); +} + + +//-------------------------------------------------------------------------------------------------------------- + +__global__ +void copyToHalfKernel(const float *in, half* out, int n) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx < n) { + out[idx] = __float2half(in[idx]); + } +} +__global__ +void copyFromHalfKernel(const half *in, float* out, int n) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx < n) { + out[idx] = __half2float(in[idx]); + } +} + +void customCudaCopyToHalf(const float* in, half* out, int n) { + int blockSize = targetNumThreads; + int numBlocks = (n+blockSize-1)/blockSize; + copyToHalfKernel<<>>(in,out,n); +} +void customCudaCopyFromHalf(const half* in, float* out, int n) { + int blockSize = targetNumThreads; + int numBlocks = (n+blockSize-1)/blockSize; + copyFromHalfKernel<<>>(in,out,n); +} + +//-------------------------------------------------------------------------------------------------------------- + + +//-------------------------------------------------------------------------------------------------------------- +// Element-wise tensor add: buf[i] += toAdd[i], for float and half + +__global__ +void addTensorsInplaceKernel(float *buf, const float* toAdd, int nSize) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx < nSize) { + buf[idx] += toAdd[idx]; + } +} +void customCudaAddTensorsInplace(float* buf, const float* toAdd, int nSize) { + int blockSize = targetNumThreads; + int numBlocks = (nSize+blockSize-1)/blockSize; + addTensorsInplaceKernel<<>>(buf,toAdd,nSize); +} + +__global__ +void addTensorsInplaceHalfKernel(half *buf, const half* toAdd, int nSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx < nSize) { + buf[idx] = __hadd(buf[idx],toAdd[idx]); + } +#else + //Do nothing, FP16 not supported +#endif +} +void customCudaAddTensorsInplace(half* buf, const half* toAdd, int nSize) { + int blockSize = targetNumThreads; + int numBlocks = (nSize+blockSize-1)/blockSize; + addTensorsInplaceHalfKernel<<>>(buf,toAdd,nSize); +} + +//-------------------------------------------------------------------------------------------------------------- + +__global__ +void addTensorInplaceHalfKernel(half *buf, const half* biases, int nSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx < nSize) { + buf[idx] = __hadd(buf[idx],biases[idx]); + } +#else + //Do nothing, FP16 not supported +#endif +} +void customCudaAddTensorInplace(half* buf, const half* biases, int nSize) { + int blockSize = targetNumThreads; + int numBlocks = (nSize+blockSize-1)/blockSize; + addTensorInplaceHalfKernel<<>>(buf,biases,nSize); +} + +//-------------------------------------------------------------------------------------------------------------- + + +__global__ +void addCBiasInplaceNCKernel(float *buf, const float* biases, int nSize, int cSize) +{ + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int nIdx = blockIdx.y * blockDim.y + threadIdx.y; + if(cIdx < cSize && nIdx < nSize) { + int idx = nIdx * cSize + cIdx; + buf[idx] = buf[idx] + biases[cIdx]; + } +} +__global__ +void addCBiasInplaceNCHalfKernel(half *buf, const half* biases, int nSize, int cSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int nIdx = blockIdx.y * blockDim.y + threadIdx.y; + if(cIdx < cSize && nIdx < nSize) { + int idx = nIdx * cSize + cIdx; + buf[idx] = __hadd(buf[idx],biases[cIdx]); + } +#else + //Do nothing, FP16 not supported +#endif +} + +__global__ +void addCBiasInplaceNCKernelRelu(float *buf, const float* biases, int nSize, int cSize) +{ + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int nIdx = blockIdx.y * blockDim.y + threadIdx.y; + if(cIdx < cSize && nIdx < nSize) { + int idx = nIdx * cSize + cIdx; + buf[idx] = fmaxf(buf[idx] + biases[cIdx],0.0f); + } +} +__global__ +void addCBiasInplaceNCHalfKernelRelu(half *buf, const half* biases, int nSize, int cSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int nIdx = blockIdx.y * blockDim.y + threadIdx.y; + if(cIdx < cSize && nIdx < nSize) { + int idx = nIdx * cSize + cIdx; + const half halfzero = __float2half(0.0f); + half a = __hadd(buf[idx],biases[cIdx]); + buf[idx] = __hgt(a,halfzero) ? a : halfzero; + } +#else + //Do nothing, FP16 not supported +#endif +} + +__global__ +void addCBiasInplaceNCKernelMish(float *buf, const float* biases, int nSize, int cSize) +{ + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int nIdx = blockIdx.y * blockDim.y + threadIdx.y; + if(cIdx < cSize && nIdx < nSize) { + int idx = nIdx * cSize + cIdx; + buf[idx] = mishf(buf[idx] + biases[cIdx]); + } +} +__global__ +void addCBiasInplaceNCHalfKernelMish(half *buf, const half* biases, int nSize, int cSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int nIdx = blockIdx.y * blockDim.y + threadIdx.y; + if(cIdx < cSize && nIdx < nSize) { + int idx = nIdx * cSize + cIdx; + half a = __hadd(buf[idx],biases[cIdx]); + buf[idx] = mishh(a); + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void addCBiasInplaceNCKernelMishScale8(float *buf, const float* biases, int nSize, int cSize) +{ + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int nIdx = blockIdx.y * blockDim.y + threadIdx.y; + if(cIdx < cSize && nIdx < nSize) { + int idx = nIdx * cSize + cIdx; + buf[idx] = mishf_scale8(buf[idx] + biases[cIdx]); + } +} +__global__ +void addCBiasInplaceNCHalfKernelMishScale8(half *buf, const half* biases, int nSize, int cSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int nIdx = blockIdx.y * blockDim.y + threadIdx.y; + if(cIdx < cSize && nIdx < nSize) { + int idx = nIdx * cSize + cIdx; + half a = __hadd(buf[idx],biases[cIdx]); + buf[idx] = mishh_scale8(a); + } +#else + //Do nothing, FP16 not supported +#endif +} + +void sharedAddCBiasInplaceNC(void* buf, const void* biases, int nSize, int cSize, bool isHalf, int activation) { + int cThreads; + int cBlocks; + int nThreads; + int nBlocks; + splitThreadsAcrossDim01(cSize, nSize, cThreads, cBlocks, nThreads, nBlocks); + + if(nBlocks > 65536) + throw std::runtime_error("customCudaAddCBiasInplaceNC: nSize too large given cSize"); + + dim3 grid(cBlocks,nBlocks,1); + dim3 threads(cThreads,nThreads,1); + + if(activation == ACTIVATION_IDENTITY) { + if(isHalf) + addCBiasInplaceNCHalfKernel<<>>((half*)buf,(const half*)biases,nSize,cSize); + else + addCBiasInplaceNCKernel<<>>((float*)buf,(const float*)biases,nSize,cSize); + } + else if(activation == ACTIVATION_RELU) { + if(isHalf) + addCBiasInplaceNCHalfKernelRelu<<>>((half*)buf,(const half*)biases,nSize,cSize); + else + addCBiasInplaceNCKernelRelu<<>>((float*)buf,(const float*)biases,nSize,cSize); + } + else if(activation == ACTIVATION_MISH) { + if(isHalf) + addCBiasInplaceNCHalfKernelMish<<>>((half*)buf,(const half*)biases,nSize,cSize); + else + addCBiasInplaceNCKernelMish<<>>((float*)buf,(const float*)biases,nSize,cSize); + } + else if(activation == ACTIVATION_MISH_SCALE8) { + if(isHalf) + addCBiasInplaceNCHalfKernelMishScale8<<>>((half*)buf,(const half*)biases,nSize,cSize); + else + addCBiasInplaceNCKernelMishScale8<<>>((float*)buf,(const float*)biases,nSize,cSize); + } + else { + throw std::runtime_error("customCudaAddCBiasInplaceNC: unsupported activation"); + } +} + +void customCudaAddCBiasInplaceNC(float* buf, const float* biases, int nSize, int cSize, int activation) { + sharedAddCBiasInplaceNC(buf,biases,nSize,cSize,false,activation); +} +void customCudaAddCBiasInplaceNC(half* buf, const half* biases, int nSize, int cSize, int activation) { + sharedAddCBiasInplaceNC(buf,biases,nSize,cSize,true,activation); +} + +//-------------------------------------------------------------------------------------------------------------- + +__global__ +void addNCBiasInplaceNCHWKernel(float *buf, const float* biases, int cSize, int sSize) +{ + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int ncIdx = nIdx * cSize + cIdx; + int idx = ncIdx * sSize + sIdx; + buf[idx] = buf[idx] + biases[ncIdx]; + } +} +__global__ +void addNCBiasInplaceNCHWHalfKernel(half *buf, const half* biases, int cSize, int sSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int ncIdx = nIdx * cSize + cIdx; + int idx = ncIdx * sSize + sIdx; + buf[idx] = __hadd(buf[idx],biases[ncIdx]); + } +#else + //Do nothing, FP16 not supported +#endif +} + +void sharedAddNCBiasInplaceNCHW(void *buf, const void* biases, int nSize, int cSize, int xySize, bool isHalf) { + if(nSize > 65536) + throw std::runtime_error("customCudaAddNCBiasInplaceNCHW: nSize too large"); + if(cSize > 65536) + throw std::runtime_error("customCudaAddNCBiasInplaceNCHW: cSize too large"); + + int sSize = xySize; + int sThreads; + int sBlocks; + int cThreads; + int cBlocks; + splitThreadsAcrossDim01(sSize, cSize, sThreads, sBlocks, cThreads, cBlocks); + + dim3 grid(sBlocks,cBlocks,nSize); + dim3 threads(sThreads,cThreads,1); + if(isHalf) + addNCBiasInplaceNCHWHalfKernel<<>>((half*)buf,(const half*)biases,cSize,sSize); + else + addNCBiasInplaceNCHWKernel<<>>((float*)buf,(const float*)biases,cSize,sSize); +} + +void customCudaAddNCBiasInplaceNCHW(float *buf, const float* biases, int nSize, int cSize, int xySize) { + sharedAddNCBiasInplaceNCHW(buf,biases,nSize,cSize,xySize,false); +} +void customCudaAddNCBiasInplaceNCHW(half *buf, const half* biases, int nSize, int cSize, int xySize) { + sharedAddNCBiasInplaceNCHW(buf,biases,nSize,cSize,xySize,true); +} + +//-------------------------------------------------------------------------------------------------------------- + +__global__ +void addNCBiasInplaceNHWCKernel(float *buf, const float* biases, int sSize, int cSize) +{ + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int ncIdx = nIdx * cSize + cIdx; + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + buf[idx] = buf[idx] + biases[ncIdx]; + } +} +__global__ +void addNCBiasInplaceNHWCHalfKernel(half *buf, const half* biases, int sSize, int cSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int ncIdx = nIdx * cSize + cIdx; + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + buf[idx] = __hadd(buf[idx],biases[ncIdx]); + } +#else + //Do nothing, FP16 not supported +#endif +} + +void sharedAddNCBiasInplaceNHWC(void *buf, const void* biases, int nSize, int xySize, int cSize, bool isHalf) { + if(nSize > 65536) + throw std::runtime_error("customCudaAddNCBiasInplaceNHWC: nSize too large"); + if(xySize > 65536) + throw std::runtime_error("customCudaAddNCBiasInplaceNHWC: xySize too large"); + + int sSize = xySize; + int cThreads; + int cBlocks; + int sThreads; + int sBlocks; + splitThreadsAcrossDim01(cSize, sSize, cThreads, cBlocks, sThreads, sBlocks); + + dim3 grid(cBlocks,sBlocks,nSize); + dim3 threads(cThreads,sThreads,1); + if(isHalf) + addNCBiasInplaceNHWCHalfKernel<<>>((half*)buf,(const half*)biases,sSize,cSize); + else + addNCBiasInplaceNHWCKernel<<>>((float*)buf,(const float*)biases,sSize,cSize); +} + +void customCudaAddNCBiasInplaceNHWC(float *buf, const float* biases, int nSize, int xySize, int cSize) { + sharedAddNCBiasInplaceNHWC(buf,biases,nSize,xySize,cSize,false); +} +void customCudaAddNCBiasInplaceNHWC(half *buf, const half* biases, int nSize, int xySize, int cSize) { + sharedAddNCBiasInplaceNHWC(buf,biases,nSize,xySize,cSize,true); +} + +//-------------------------------------------------------------------------------------------------------------- + +__global__ +void applyCScaleBiasNCHWKernel(const float *in, float* out, const float* scale, const float* biases, int cSize, int sSize) +{ + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + out[idx] = in[idx] * scale[cIdx] + biases[cIdx]; + } +} +__global__ +void applyCScaleBiasNCHWReluKernel(const float *in, float* out, const float* scale, const float* biases, int cSize, int sSize) +{ + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + out[idx] = fmaxf(in[idx] * scale[cIdx] + biases[cIdx],0.0f); + } +} +__global__ +void applyCScaleBiasNCHWMishKernel(const float *in, float* out, const float* scale, const float* biases, int cSize, int sSize) +{ + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + out[idx] = mishf(in[idx] * scale[cIdx] + biases[cIdx]); + } +} +__global__ +void applyCScaleBiasNCHWMishScale8Kernel(const float *in, float* out, const float* scale, const float* biases, int cSize, int sSize) +{ + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + out[idx] = mishf_scale8(in[idx] * scale[cIdx] + biases[cIdx]); + } +} +__global__ +void applyCScaleBiasNCHWMaskKernel(const float *in, float* out, const float* scale, const float* biases, const float* mask, int cSize, int sSize) +{ + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + out[idx] = (in[idx] * scale[cIdx] + biases[cIdx]) * mask[nIdx*sSize+sIdx]; + } +} +__global__ +void applyCScaleBiasNCHWReluMaskKernel(const float *in, float* out, const float* scale, const float* biases, const float* mask, int cSize, int sSize) +{ + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + out[idx] = fmaxf(in[idx] * scale[cIdx] + biases[cIdx],0.0f) * mask[nIdx*sSize+sIdx]; + } +} +__global__ +void applyCScaleBiasNCHWMishMaskKernel(const float *in, float* out, const float* scale, const float* biases, const float* mask, int cSize, int sSize) +{ + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + out[idx] = mishf(in[idx] * scale[cIdx] + biases[cIdx]) * mask[nIdx*sSize+sIdx]; + } +} +__global__ +void applyCScaleBiasNCHWMishScale8MaskKernel(const float *in, float* out, const float* scale, const float* biases, const float* mask, int cSize, int sSize) +{ + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + out[idx] = mishf_scale8(in[idx] * scale[cIdx] + biases[cIdx]) * mask[nIdx*sSize+sIdx]; + } +} +__global__ +void applyCScaleBiasNCHWHalfKernel(const half *in, half* out, const half* scale, const half* biases, int cSize, int sSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + out[idx] = __hfma(in[idx],scale[cIdx],biases[cIdx]); + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void applyCScaleBiasNCHWReluHalfKernel(const half *in, half* out, const half* scale, const half* biases, int cSize, int sSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + half a = __hfma(in[idx],scale[cIdx],biases[cIdx]); + const half halfzero = __float2half(0.0f); + out[idx] = __hgt(a,halfzero) ? a : halfzero; + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void applyCScaleBiasNCHWMishHalfKernel(const half *in, half* out, const half* scale, const half* biases, int cSize, int sSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + half a = __hfma(in[idx],scale[cIdx],biases[cIdx]); + out[idx] = mishh(a); + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void applyCScaleBiasNCHWMishScale8HalfKernel(const half *in, half* out, const half* scale, const half* biases, int cSize, int sSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + half a = __hfma(in[idx],scale[cIdx],biases[cIdx]); + out[idx] = mishh_scale8(a); + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void applyCScaleBiasNCHWMaskHalfKernel(const half *in, half* out, const half* scale, const half* biases, const half* mask, int cSize, int sSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + out[idx] = __hmul(__hfma(in[idx],scale[cIdx],biases[cIdx]),mask[nIdx*sSize+sIdx]); + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void applyCScaleBiasNCHWReluMaskHalfKernel(const half *in, half* out, const half* scale, const half* biases, const half* mask, int cSize, int sSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + half a = __hmul(__hfma(in[idx],scale[cIdx],biases[cIdx]),mask[nIdx*sSize+sIdx]); + const half halfzero = __float2half(0.0f); + out[idx] = __hgt(a,halfzero) ? a : halfzero; + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void applyCScaleBiasNCHWMishMaskHalfKernel(const half *in, half* out, const half* scale, const half* biases, const half* mask, int cSize, int sSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + half a = __hmul(__hfma(in[idx],scale[cIdx],biases[cIdx]),mask[nIdx*sSize+sIdx]); + out[idx] = mishh(a); + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void applyCScaleBiasNCHWMishScale8MaskHalfKernel(const half *in, half* out, const half* scale, const half* biases, const half* mask, int cSize, int sSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int sIdx = blockIdx.x * blockDim.x + threadIdx.x; + int cIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * cSize + cIdx) * sSize + sIdx; + half a = __hmul(__hfma(in[idx],scale[cIdx],biases[cIdx]),mask[nIdx*sSize+sIdx]); + out[idx] = mishh_scale8(a); + } +#else + //Do nothing, FP16 not supported +#endif +} + +void sharedApplyCScaleBiasNCHW(const void* in, void* out, const void* scale, const void* biases, const void* mask, int nSize, int cSize, int xySize, bool isHalf, int activation) { + if(nSize > 65536) + throw std::runtime_error("customCudaApplyCScaleBiasNCHW: nSize too large"); + if(cSize > 65536) + throw std::runtime_error("customCudaApplyCScaleBiasNCHW: cSize too large"); + + int sSize = xySize; + int sThreads; + int sBlocks; + int cThreads; + int cBlocks; + splitThreadsAcrossDim01(sSize, cSize, sThreads, sBlocks, cThreads, cBlocks); + + dim3 grid(sBlocks,cBlocks,nSize); + dim3 threads(sThreads,cThreads,1); + if(mask == NULL) { + if(activation == ACTIVATION_IDENTITY) { + if(isHalf) + applyCScaleBiasNCHWHalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,cSize,sSize); + else + applyCScaleBiasNCHWKernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,cSize,sSize); + } + else if(activation == ACTIVATION_RELU) { + if(isHalf) + applyCScaleBiasNCHWReluHalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,cSize,sSize); + else + applyCScaleBiasNCHWReluKernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,cSize,sSize); + } + else if(activation == ACTIVATION_MISH) { + if(isHalf) + applyCScaleBiasNCHWMishHalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,cSize,sSize); + else + applyCScaleBiasNCHWMishKernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,cSize,sSize); + } + else if(activation == ACTIVATION_MISH_SCALE8) { + if(isHalf) + applyCScaleBiasNCHWMishScale8HalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,cSize,sSize); + else + applyCScaleBiasNCHWMishScale8Kernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,cSize,sSize); + } + else { + throw std::runtime_error("customCudaApplyCScaleBiasNCHW: unsupported activation"); + } + } + else { + if(activation == ACTIVATION_IDENTITY) { + if(isHalf) + applyCScaleBiasNCHWMaskHalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,(const half*)mask,cSize,sSize); + else + applyCScaleBiasNCHWMaskKernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,(const float*)mask,cSize,sSize); + } + else if(activation == ACTIVATION_RELU) { + if(isHalf) + applyCScaleBiasNCHWReluMaskHalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,(const half*)mask,cSize,sSize); + else + applyCScaleBiasNCHWReluMaskKernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,(const float*)mask,cSize,sSize); + } + else if(activation == ACTIVATION_MISH) { + if(isHalf) + applyCScaleBiasNCHWMishMaskHalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,(const half*)mask,cSize,sSize); + else + applyCScaleBiasNCHWMishMaskKernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,(const float*)mask,cSize,sSize); + } + else if(activation == ACTIVATION_MISH_SCALE8) { + if(isHalf) + applyCScaleBiasNCHWMishScale8MaskHalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,(const half*)mask,cSize,sSize); + else + applyCScaleBiasNCHWMishScale8MaskKernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,(const float*)mask,cSize,sSize); + } + else { + throw std::runtime_error("customCudaApplyCScaleBiasNCHW: unsupported activation"); + } + } +} + +void customCudaApplyCScaleBiasNCHW(const float* in, float* out, const float* scale, const float* biases, const float* mask, int nSize, int cSize, int xySize, int activation) { + sharedApplyCScaleBiasNCHW(in,out,scale,biases,mask,nSize,cSize,xySize,false,activation); +} +void customCudaApplyCScaleBiasNCHW(const half* in, half* out, const half* scale, const half* biases, const half* mask, int nSize, int cSize, int xySize, int activation) { + sharedApplyCScaleBiasNCHW(in,out,scale,biases,mask,nSize,cSize,xySize,true,activation); +} + + +//-------------------------------------------------------------------------------------------------------------- + +__global__ +void applyCScaleBiasNHWCKernel(const float* in, float* out, const float* scale, const float* biases, int sSize, int cSize) +{ + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + out[idx] = in[idx] * scale[cIdx] + biases[cIdx]; + } +} +__global__ +void applyCScaleBiasNHWCReluKernel(const float* in, float* out, const float* scale, const float* biases, int sSize, int cSize) +{ + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + out[idx] = fmaxf(in[idx] * scale[cIdx] + biases[cIdx],0.0f); + } +} +__global__ +void applyCScaleBiasNHWCMishKernel(const float* in, float* out, const float* scale, const float* biases, int sSize, int cSize) +{ + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + out[idx] = mishf(in[idx] * scale[cIdx] + biases[cIdx]); + } +} +__global__ +void applyCScaleBiasNHWCMishScale8Kernel(const float* in, float* out, const float* scale, const float* biases, int sSize, int cSize) +{ + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + out[idx] = mishf_scale8(in[idx] * scale[cIdx] + biases[cIdx]); + } +} +__global__ +void applyCScaleBiasNHWCMaskKernel(const float* in, float* out, const float* scale, const float* biases, const float* mask, int sSize, int cSize) +{ + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + out[idx] = (in[idx] * scale[cIdx] + biases[cIdx]) * mask[nIdx*sSize+sIdx]; + } +} +__global__ +void applyCScaleBiasNHWCReluMaskKernel(const float* in, float* out, const float* scale, const float* biases, const float* mask, int sSize, int cSize) +{ + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + out[idx] = fmaxf(in[idx] * scale[cIdx] + biases[cIdx],0.0f) * mask[nIdx*sSize+sIdx]; + } +} +__global__ +void applyCScaleBiasNHWCMishMaskKernel(const float* in, float* out, const float* scale, const float* biases, const float* mask, int sSize, int cSize) +{ + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + out[idx] = mishf(in[idx] * scale[cIdx] + biases[cIdx]) * mask[nIdx*sSize+sIdx]; + } +} +__global__ +void applyCScaleBiasNHWCMishScale8MaskKernel(const float* in, float* out, const float* scale, const float* biases, const float* mask, int sSize, int cSize) +{ + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + out[idx] = mishf_scale8(in[idx] * scale[cIdx] + biases[cIdx]) * mask[nIdx*sSize+sIdx]; + } +} +__global__ +void applyCScaleBiasNHWCHalfKernel(const half* in, half* out, const half* scale, const half* biases, int sSize, int cSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + out[idx] = __hfma(in[idx],scale[cIdx],biases[cIdx]); + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void applyCScaleBiasNHWCReluHalfKernel(const half* in, half* out, const half* scale, const half* biases, int sSize, int cSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + half a = __hfma(in[idx],scale[cIdx],biases[cIdx]); + const half halfzero = __float2half(0.0f); + out[idx] = __hgt(a,halfzero) ? a : halfzero; + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void applyCScaleBiasNHWCMishHalfKernel(const half* in, half* out, const half* scale, const half* biases, int sSize, int cSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + half a = __hfma(in[idx],scale[cIdx],biases[cIdx]); + out[idx] = mishh(a); + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void applyCScaleBiasNHWCMishScale8HalfKernel(const half* in, half* out, const half* scale, const half* biases, int sSize, int cSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + half a = __hfma(in[idx],scale[cIdx],biases[cIdx]); + out[idx] = mishh_scale8(a); + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void applyCScaleBiasNHWCMaskHalfKernel(const half* in, half* out, const half* scale, const half* biases, const half* mask, int sSize, int cSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + out[idx] = __hmul(__hfma(in[idx],scale[cIdx],biases[cIdx]),mask[nIdx*sSize+sIdx]); + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void applyCScaleBiasNHWCReluMaskHalfKernel(const half* in, half* out, const half* scale, const half* biases, const half* mask, int sSize, int cSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + half a = __hmul(__hfma(in[idx],scale[cIdx],biases[cIdx]),mask[nIdx*sSize+sIdx]); + const half halfzero = __float2half(0.0f); + out[idx] = __hgt(a,halfzero) ? a : halfzero; + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void applyCScaleBiasNHWCMishMaskHalfKernel(const half* in, half* out, const half* scale, const half* biases, const half* mask, int sSize, int cSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + half a = __hmul(__hfma(in[idx],scale[cIdx],biases[cIdx]),mask[nIdx*sSize+sIdx]); + out[idx] = mishh(a); + } +#else + //Do nothing, FP16 not supported +#endif +} +__global__ +void applyCScaleBiasNHWCMishScale8MaskHalfKernel(const half* in, half* out, const half* scale, const half* biases, const half* mask, int sSize, int cSize) +{ +#ifdef HIP_SUPPORTS_FP16 + int cIdx = blockIdx.x * blockDim.x + threadIdx.x; + int sIdx = blockIdx.y * blockDim.y + threadIdx.y; + int nIdx = blockIdx.z; + if(cIdx < cSize && sIdx < sSize) { + int idx = (nIdx * sSize + sIdx) * cSize + cIdx; + half a = __hmul(__hfma(in[idx],scale[cIdx],biases[cIdx]),mask[nIdx*sSize+sIdx]); + out[idx] = mishh_scale8(a); + } +#else + //Do nothing, FP16 not supported +#endif +} + +void sharedApplyCScaleBiasNHWC(const void* in, void* out, const void* scale, const void* biases, const void* mask, int nSize, int xySize, int cSize, bool isHalf, int activation) { + if(nSize > 65536) + throw std::runtime_error("customCudaApplyCScaleBiasNHWC: nSize too large"); + if(xySize > 65536) + throw std::runtime_error("customCudaApplyCScaleBiasNHWC: xySize too large"); + + int sSize = xySize; + int cThreads; + int cBlocks; + int sThreads; + int sBlocks; + splitThreadsAcrossDim01(cSize, sSize, cThreads, cBlocks, sThreads, sBlocks); + + dim3 grid(cBlocks,sBlocks,nSize); + dim3 threads(cThreads,sThreads,1); + if(mask == NULL) { + if(activation == ACTIVATION_IDENTITY) { + if(isHalf) + applyCScaleBiasNHWCHalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,sSize,cSize); + else + applyCScaleBiasNHWCKernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,sSize,cSize); + } + else if(activation == ACTIVATION_RELU) { + if(isHalf) + applyCScaleBiasNHWCReluHalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,sSize,cSize); + else + applyCScaleBiasNHWCReluKernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,sSize,cSize); + } + else if(activation == ACTIVATION_MISH) { + if(isHalf) + applyCScaleBiasNHWCMishHalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,sSize,cSize); + else + applyCScaleBiasNHWCMishKernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,sSize,cSize); + } + else if(activation == ACTIVATION_MISH_SCALE8) { + if(isHalf) + applyCScaleBiasNHWCMishScale8HalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,sSize,cSize); + else + applyCScaleBiasNHWCMishScale8Kernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,sSize,cSize); + } + else { + throw std::runtime_error("customCudaApplyCScaleBiasNHWC: unsupported activation"); + } + } + else { + if(activation == ACTIVATION_IDENTITY) { + if(isHalf) + applyCScaleBiasNHWCMaskHalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,(const half*)mask,sSize,cSize); + else + applyCScaleBiasNHWCMaskKernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,(const float*)mask,sSize,cSize); + } + else if(activation == ACTIVATION_RELU) { + if(isHalf) + applyCScaleBiasNHWCReluMaskHalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,(const half*)mask,sSize,cSize); + else + applyCScaleBiasNHWCReluMaskKernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,(const float*)mask,sSize,cSize); + } + else if(activation == ACTIVATION_MISH) { + if(isHalf) + applyCScaleBiasNHWCMishMaskHalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,(const half*)mask,sSize,cSize); + else + applyCScaleBiasNHWCMishMaskKernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,(const float*)mask,sSize,cSize); + } + else if(activation == ACTIVATION_MISH_SCALE8) { + if(isHalf) + applyCScaleBiasNHWCMishScale8MaskHalfKernel<<>>((const half*)in,(half*)out,(const half*)scale,(const half*)biases,(const half*)mask,sSize,cSize); + else + applyCScaleBiasNHWCMishScale8MaskKernel<<>>((const float*)in,(float*)out,(const float*)scale,(const float*)biases,(const float*)mask,sSize,cSize); + } + else { + throw std::runtime_error("customCudaApplyCScaleBiasNHWC: unsupported activation"); + } + } +} + +void customCudaApplyCScaleBiasNHWC(const float* in, float* out, const float* scale, const float* biases, const float* mask, int nSize, int xySize, int cSize, int activation) { + sharedApplyCScaleBiasNHWC(in,out,scale,biases,mask,nSize,xySize,cSize,false,activation); +} +void customCudaApplyCScaleBiasNHWC(const half* in, half* out, const half* scale, const half* biases, const half* mask, int nSize, int xySize, int cSize, int activation) { + sharedApplyCScaleBiasNHWC(in,out,scale,biases,mask,nSize,xySize,cSize,true,activation); +} diff --git a/cpp/neuralnet/rocmincludes.h b/cpp/neuralnet/rocmincludes.h new file mode 100644 index 000000000..8b494a37e --- /dev/null +++ b/cpp/neuralnet/rocmincludes.h @@ -0,0 +1,15 @@ +#ifndef NEURALNET_ROCMINCLUDES_H +#define NEURALNET_ROCMINCLUDES_H + +//Ensure that CUDA_API_PER_THREAD_DEFAULT_STREAM is always defined +//before any cuda headers are included so that we get the desired threading behavior for CUDA. + +#define CUDA_API_PER_THREAD_DEFAULT_STREAM +#include +#include + +#include +#include + + +#endif //NEURALNET_ROCMINCLUDES_H diff --git a/cpp/neuralnet/rocmutils.cpp b/cpp/neuralnet/rocmutils.cpp new file mode 100644 index 000000000..752298b7f --- /dev/null +++ b/cpp/neuralnet/rocmutils.cpp @@ -0,0 +1,170 @@ +#include "../neuralnet/rocmutils.h" + +#include +#include "../neuralnet/rocmerrorcheck.h" +#include "../neuralnet/rocmincludes.h" +#include "../neuralnet/rocmhelpers.h" + +#include "../external/half-2.2.0/include/half.hpp" + +//------------------------ +#include "../core/using.h" +//------------------------ + +using half_t = half_float::half; + +void CudaUtils::mallocOnDevice(const string& name, int numWeights, void*& deviceBuf, bool useFP16) { + if(useFP16) { + size_t halfBytes = numWeights * sizeof(half_t); + CUDA_ERR(name.c_str(),hipMalloc(&deviceBuf, halfBytes)); + } + else { + size_t floatBytes = numWeights * sizeof(float); + CUDA_ERR(name.c_str(),hipMalloc(&deviceBuf, floatBytes)); + } +} + +void CudaUtils::mallocAndCopyToDevice(const string& name, const vector& weights, void*& deviceBuf, bool useFP16) { + size_t numWeights = weights.size(); + if(useFP16) { + size_t halfBytes = numWeights * sizeof(half_t); + vector weightsHalf(weights.size()); + for(size_t i = 0; i(weights[i]); + CUDA_ERR(name.c_str(),hipMalloc(&deviceBuf, halfBytes)); + CUDA_ERR(name.c_str(),hipMemcpy(deviceBuf, weightsHalf.data(), halfBytes, hipMemcpyHostToDevice)); + } + else { + size_t floatBytes = numWeights * sizeof(float); + CUDA_ERR(name.c_str(),hipMalloc(&deviceBuf, floatBytes)); + CUDA_ERR(name.c_str(),hipMemcpy(deviceBuf, weights.data(), floatBytes, hipMemcpyHostToDevice)); + } +} + +void CudaUtils::mallocAndCopyToDevice(const string& name, const float* weights, int numWeights, void*& deviceBuf, bool useFP16) { + if(useFP16) { + size_t halfBytes = numWeights * sizeof(half_t); + vector weightsHalf(numWeights); + for(int i = 0; i(weights[i]); + CUDA_ERR(name.c_str(),hipMalloc(&deviceBuf, halfBytes)); + CUDA_ERR(name.c_str(),hipMemcpy(deviceBuf, weightsHalf.data(), halfBytes, hipMemcpyHostToDevice)); + } + else { + size_t floatBytes = numWeights * sizeof(float); + CUDA_ERR(name.c_str(),hipMalloc(&deviceBuf, floatBytes)); + CUDA_ERR(name.c_str(),hipMemcpy(deviceBuf, weights, floatBytes, hipMemcpyHostToDevice)); + } +} + +//Only use in testing, allocates an intermediate buffer in the case of FP16 which will be very slow. +void CudaUtils::expensiveCopyFromDevice(const string& name, float* weights, int numWeights, const void* deviceBuf, bool useFP16) { + if(useFP16) { + vector weightsHalf(numWeights); + size_t halfBytes = numWeights * sizeof(half_t); + CUDA_ERR(name.c_str(),hipMemcpy(weightsHalf.data(), deviceBuf, halfBytes, hipMemcpyDeviceToHost)); + for(int i = 0; i values(batchSize * cSize); + expensiveCopyFromDevice(name, values.data(), values.size(), deviceBuf, useFP16); + cout << "=========================================================" << endl; + cout << "TENSOR" << endl; + cout << name << endl; + cout << std::setprecision(8); + int i = 0; + for(int n = 0; n values(batchSize * cSize * xSize * ySize); + expensiveCopyFromDevice(name, values.data(), values.size(), deviceBuf, useFP16); + cout << "=========================================================" << endl; + cout << "TENSOR" << endl; + cout << name << endl; + cout << std::setprecision(8); + int i = 0; + double total1 = 0; + double total2 = 0; + double total3 = 0; + for(int n = 0; n= (int64_t)1 << 31) + throw StringError("Batch size too large, resulting GPU buffers might exceed 2^31 entries which is not currently supported"); +} + +void CudaUtils::hostMallocZeroOneBufs(void*& zeroBuf, void*& oneBuf, bool useFP16) { + if(!useFP16) { + zeroBuf = malloc(sizeof(float)); + oneBuf = malloc(sizeof(float)); + *((float*)zeroBuf) = 0.0f; + *((float*)oneBuf) = 1.0f; + } + else { + //Convert to FP16 on the device, then copy back so we have it in host memory + float zero = 0.0f; + float one = 1.0f; + void* zeroTmp; + void* oneTmp; + mallocAndCopyToDevice("Buffers",&zero,1,zeroTmp,useFP16); + mallocAndCopyToDevice("Buffers",&one,1,oneTmp,useFP16); + zeroBuf = malloc(sizeof(half_t)); + oneBuf = malloc(sizeof(half_t)); + CUDA_ERR("Buffers",hipMemcpy(zeroBuf,zeroTmp,sizeof(half_t),hipMemcpyDeviceToHost)); + CUDA_ERR("Buffers",hipMemcpy(oneBuf,oneTmp,sizeof(half_t),hipMemcpyDeviceToHost)); + hipFree(zeroTmp); + hipFree(oneTmp); + } +} diff --git a/cpp/neuralnet/rocmutils.h b/cpp/neuralnet/rocmutils.h new file mode 100644 index 000000000..d43868402 --- /dev/null +++ b/cpp/neuralnet/rocmutils.h @@ -0,0 +1,21 @@ +#ifndef NEURALNET_ROCMUTILS_H +#define NEURALNET_ROCMUTILS_H + +#include "../core/global.h" + +namespace CudaUtils { + void mallocOnDevice(const std::string& name, int numWeights, void*& deviceBuf, bool useFP16); + void mallocAndCopyToDevice(const std::string& name, const std::vector& weights, void*& deviceBuf, bool useFP16); + void mallocAndCopyToDevice(const std::string& name, const float* weights, int numWeights, void*& deviceBuf, bool useFP16); + + //Only use in testing, allocates an intermediate buffer in the case of FP16 which will be very slow. + void expensiveCopyFromDevice(const std::string& name, float* weights, int numWeights, const void* deviceBuf, bool useFP16); + + void debugPrint2D(const std::string& name, const void* deviceBuf, int batchSize, int cSize, bool useFP16); + void debugPrint4D(const std::string& name, const void* deviceBuf, int batchSize, int cSize, int xSize, int ySize, bool useNHWC, bool useFP16); + + void checkBufferSize(int batchSize, int xSize, int ySize, int channels); + void hostMallocZeroOneBufs(void*& zeroBuf, void*& oneBuf, bool useFP16); +} + +#endif // NEURALNET_ROCMUTILS_H diff --git a/cpp/program/gtpconfig.cpp b/cpp/program/gtpconfig.cpp index 7a45c02de..b03de4618 100644 --- a/cpp/program/gtpconfig.cpp +++ b/cpp/program/gtpconfig.cpp @@ -535,6 +535,12 @@ string GTPConfig::makeConfig( #endif #ifdef USE_OPENCL_BACKEND replacement += "openclDeviceToUseThread" + Global::intToString(i) + " = " + Global::intToString(deviceIdxs[i]) + "\n"; +#endif +#ifdef USE_ROCM_BACKEND + replacement += "rocmDeviceToUseThread" + Global::intToString(i) + " = " + Global::intToString(deviceIdxs[i]) + "\n"; +#endif +#ifdef USE_MIGRAPHX_BACKEND + replacement += "mgxDeviceToUseThread" + Global::intToString(i) + " = " + Global::intToString(deviceIdxs[i]) + "\n"; #endif } replace("$$MULTIPLE_GPUS", replacement); diff --git a/cpp/program/setup.cpp b/cpp/program/setup.cpp index 60baac228..aa798b758 100644 --- a/cpp/program/setup.cpp +++ b/cpp/program/setup.cpp @@ -19,6 +19,8 @@ std::vector Setup::getBackendPrefixes() { prefixes.push_back("trt"); prefixes.push_back("metal"); prefixes.push_back("opencl"); + prefixes.push_back("rocm"); + prefixes.push_back("mgx"); prefixes.push_back("eigen"); prefixes.push_back("dummybackend"); return prefixes; @@ -86,6 +88,10 @@ vector Setup::initializeNNEvaluators( string backendPrefix = "metal"; #elif defined(USE_OPENCL_BACKEND) string backendPrefix = "opencl"; + #elif defined(USE_ROCM_BACKEND) + string backendPrefix = "rocm"; + #elif defined(USE_MIGRAPHX_BACKEND) + string backendPrefix = "mgx"; #elif defined(USE_EIGEN_BACKEND) string backendPrefix = "eigen"; #else @@ -141,7 +147,7 @@ vector Setup::initializeNNEvaluators( requireExactNNLen = cfg.getBool("requireMaxBoardSize"); } - bool inputsUseNHWC = backendPrefix == "opencl" || backendPrefix == "trt" || backendPrefix == "metal" ? false : true; + bool inputsUseNHWC = backendPrefix == "opencl" || backendPrefix == "trt" || backendPrefix == "metal" || backendPrefix == "rocm" || backendPrefix == "mgx" ? false : true; if(cfg.contains(backendPrefix+"InputsUseNHWC"+idxStr)) inputsUseNHWC = cfg.getBool(backendPrefix+"InputsUseNHWC"+idxStr); else if(cfg.contains("inputsUseNHWC"+idxStr))