diff --git a/cmake/QuadrantsCore.cmake b/cmake/QuadrantsCore.cmake index 4e1bd4c606..c1cce89f5c 100644 --- a/cmake/QuadrantsCore.cmake +++ b/cmake/QuadrantsCore.cmake @@ -66,6 +66,7 @@ file(GLOB QUADRANTS_CORE_SOURCE "quadrants/jit/*" "quadrants/math/*" "quadrants/program/*" + "quadrants/program/adstack/*" "quadrants/struct/*" "quadrants/system/*" "quadrants/transforms/*" diff --git a/docs/source/user_guide/autodiff.md b/docs/source/user_guide/autodiff.md index e9885bfb6d..651265c030 100644 --- a/docs/source/user_guide/autodiff.md +++ b/docs/source/user_guide/autodiff.md @@ -350,10 +350,13 @@ A large `ndrange` combined with several loop-carried variables multiplies quickl ## What can go wrong - **Adstack overflow.** Surfaces as `QuadrantsAssertionError: Adstack overflow ...` at the next Quadrants Python entry. The message names the offending kernel + offload task and the most likely cause: - - *Untracked tensor mutation between launches.* A tensor backing a data-dependent loop bound was written to outside Quadrants's tracking - typically a DLPack zero-copy mutation through a torch tensor sharing storage with a Quadrants ndarray, or a raw pointer write through a non-torch consumer. The cached adstack capacity was sized against the value before the mutation; if the mutation grew the bound, the next launch overflows. Fix: route the write through a Quadrants API (`Ndarray.write` / `Ndarray.fill` / a kernel that writes the value). Alternatively, catch the exception and re-launch - Quadrants invalidates the cached bound on raise, so the retry runs against the live state. Kernel state may be inconsistent after an overflow; do not retry the same step without restarting from a clean state. + - *Untracked tensor mutation between launches.* A tensor backing a data-dependent loop bound was written to outside Quadrants's tracking - typically a DLPack zero-copy mutation through a torch tensor sharing storage with a Quadrants ndarray, or a raw pointer write through a non-torch consumer. The cached adstack capacity was sized against the value before the mutation; if the mutation grew the bound, the next launch overflows. Workaround: route the write through a Quadrants API (`Ndarray.write` / `Ndarray.fill` / a kernel that writes the value). Alternatively, catch the exception and re-launch - Quadrants invalidates the cached bound on raise, so the retry runs against the live state. Kernel state may be inconsistent after an overflow; do not retry the same step without restarting from a clean state. - *Sizer under-estimated the bound (Quadrants bug).* On unusually intricate nested loops - typically deeply nested `for i in range(arr[...])` with cumulative-index arithmetic - the sizer can compute a bound that is mathematically tighter than the actual push count. To file a bug: clear `/tmp/ir/`, rerun your script with `QD_DUMP_IR=1` set in the environment so Quadrants dumps the kernel IR there, then open an issue on the Quadrants repo with the contents of `/tmp/ir/` attached as a zip. Workaround: pass a generous `ad_stack_size=N` to `qd.init()` with `N` large enough to cover the real push count (bypasses the sizer). - **Out-of-memory before the kernel even runs.** A reverse pass through many loop-carried variables at a large ndrange can ask the runtime for more adstack memory than the device can physically back, even when the sizer's number is correct. Surfaces as an allocator OOM at launch time. 
Remedies are the ones listed under *Avoiding OOM on GPU* above: fewer loop-carried variables, a smaller ndrange, manual checkpointing, or more device-memory headroom.
- **Loop bounds backed by a mutated ndarray.** A reverse-mode kernel with `for i in range(n[j])` requires `n[j]` to hold the same value at the forward call and at `.grad()`. If anything writes to `n[j]` between those two points - the differentiable kernel itself, or any other kernel call - the backward call will trigger an `Adstack overflow` exception, or the computed gradient will come out silently wrong. The safe rule: populate loop-bound ndarrays before the forward call and leave them untouched until `.grad()` returns (see the sketch after this list). This follows from Quadrants' adstack sizer design: the sizer re-reads the loop bound at every dispatch, and the forward call and the backward call are separate dispatches. Tape-based eager AD like [PyTorch's autograd](https://pytorch.org/docs/stable/notes/autograd.html) is not affected, since the trip count is recorded as the forward runs and reused at backward time.
+- **Inner reverse-mode loop with a complex bound at very large extent.** An arbitrarily large enclosing range works only when the inner trip count fits a fixed subset of expressions; other shapes cap at ~16 million enclosing iterations and raise `RuntimeError: ... iteration count ... exceeds the 16777216 guard` past that. Workaround: rewrite the trip count to stay within the supported subset, or shrink the enclosing loop below the threshold.
+  - *Works at any enclosing-range size:* integer ndarray reads up to 32 bits wide (single- or multi-axis, indexed by literal constants or enclosing loop variables), field reads of the same width indexed by literal constants or enclosing loop variables (`my_field[None]`, `my_field[k]` for a constant `k`, `my_field[i]` where `i` is an enclosing loop variable), `arr.shape[k]` shape terms, literal integer constants, and `+`, `-`, `*`, `max` of those.
+  - *Caps at the threshold:* 64-bit integer ndarray or field reads, arithmetic-indexed reads (`arr[i // 2]`), and ragged inner ranges whose own bound depends on an enclosing loop variable through an unsupported leaf shape.
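+
+A minimal sketch of the safe pattern (Taichi-style kernel syntax assumed; the exact `qd.ndarray`, `needs_grad`, and `kernel.grad` spellings here are illustrative, not confirmed API):
+
+```python
+import quadrants as qd
+
+qd.init(arch=qd.gpu)  # a generous ad_stack_size=65536 here would bypass the sizer entirely
+
+n = qd.ndarray(qd.i32, shape=(16,))                   # backs the data-dependent bound
+x = qd.ndarray(qd.f32, shape=(16,), needs_grad=True)
+
+@qd.kernel
+def f(n: qd.types.ndarray(), x: qd.types.ndarray()):
+    for j in range(16):
+        for i in range(n[j]):  # bound is re-read at the forward AND the backward dispatch
+            x[j] += 0.5 * x[j]
+
+n.fill(8)     # populate the bound through a Quadrants API *before* the forward call...
+f(n, x)
+f.grad(n, x)  # ...and leave n untouched until .grad() returns
+```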

## Performance characteristics

diff --git a/quadrants/codegen/llvm/codegen_llvm.cpp b/quadrants/codegen/llvm/codegen_llvm.cpp
index 6bf252a2d1..6b7a49c91c 100644
--- a/quadrants/codegen/llvm/codegen_llvm.cpp
+++ b/quadrants/codegen/llvm/codegen_llvm.cpp
@@ -17,6 +17,7 @@
 #include "quadrants/codegen/llvm/struct_llvm.h"
 #include "quadrants/util/file_sequence_writer.h"
 #include "quadrants/codegen/codegen_utils.h"
+#include "quadrants/program/adstack_size_expr_eval.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/AsmParser/Parser.h"
 #include "quadrants/codegen/ir_dump.h"
@@ -1993,6 +1994,10 @@ void TaskCodeGenLLVM::finalize_offloaded_task_function() {
   current_task->ad_stack.allocas = ad_stack_allocas_info_;
   current_task->ad_stack.size_exprs = ad_stack_size_exprs_;
   current_task->ad_stack.bound_expr = ad_stack_static_bound_expr_;
+  // Recognize `MaxOverRange` nodes that the runtime can reduce in parallel via the dedicated max-reducer dispatch
+  // instead of letting the per-thread sizer enumerate. Indexing matches `ad_stack_size_exprs_` (same iteration order
+  // as the pre-scan above).
+  current_task->ad_stack.max_reducer_specs = recognize_adstack_max_reducer_specs(ad_stack_size_exprs_);
   // SNodes the task body mutates. Persisted on `OffloadedTask::snode_writes` so the LLVM
   // launcher can invalidate the per-task adstack metadata cache when a kernel that runs in
   // between mutates a SNode that an enclosing `size_expr::FieldLoad` reads. Mirrors the SPIR-V
diff --git a/quadrants/codegen/llvm/llvm_compiled_data.h b/quadrants/codegen/llvm/llvm_compiled_data.h
index 820b8c3476..4da1510113 100644
--- a/quadrants/codegen/llvm/llvm_compiled_data.h
+++ b/quadrants/codegen/llvm/llvm_compiled_data.h
@@ -81,6 +81,11 @@ struct AdStackSizingInfo {
   // ids are assigned per `Program` lifetime, not per-kernel-content; a deserialised task re-registers
   // itself at the next launch.
   uint32_t registry_id{0};
+  // Per-task list of `MaxOverRange` nodes the runtime reduces in parallel via a dedicated max-reducer dispatch (see the
+  // max-reducer recognizer). Empty when no captured `size_expr` contains a recognized shape. Each entry references one
+  // alloca's `size_expr` by `(stack_id, mor_node_idx)`; the runtime substitutes the dispatched value as a `Const` into
+  // the tree before the per-thread sizer walks it.
+  std::vector<StaticAdStackMaxReducerSpec> max_reducer_specs;

   QD_IO_DEF(per_thread_stride,
             per_thread_stride_float,
             per_thread_stride_int,
@@ -92,7 +97,8 @@ struct AdStackSizingInfo {
             end_offset_bytes,
             allocas,
             size_exprs,
-            bound_expr);
+            bound_expr,
+            max_reducer_specs);
 };

 class OffloadedTask {
diff --git a/quadrants/codegen/spirv/CMakeLists.txt b/quadrants/codegen/spirv/CMakeLists.txt
index 5c57cbc54f..2a3a361ef5 100644
--- a/quadrants/codegen/spirv/CMakeLists.txt
+++ b/quadrants/codegen/spirv/CMakeLists.txt
@@ -4,6 +4,7 @@ add_library(spirv_codegen)
 target_sources(spirv_codegen
   PRIVATE
     adstack_bound_reducer_shader.cpp
+    adstack_max_reducer_shader.cpp
    adstack_sizer_shader.cpp
    kernel_utils.cpp
    snode_struct_compiler.cpp
diff --git a/quadrants/codegen/spirv/adstack_max_reducer_shader.cpp b/quadrants/codegen/spirv/adstack_max_reducer_shader.cpp
new file mode 100644
index 0000000000..59763590e1
--- /dev/null
+++ b/quadrants/codegen/spirv/adstack_max_reducer_shader.cpp
@@ -0,0 +1,736 @@
+#include "quadrants/codegen/spirv/adstack_max_reducer_shader.h"
+
+#include "quadrants/codegen/spirv/spirv_ir_builder.h"
+#include "quadrants/ir/adstack_size_expr_device.h"
+#include "quadrants/ir/type.h"
+
+namespace quadrants::lang::spirv {
+
+namespace {
+
+// Number of u32 words per `AdStackSizeExprDeviceNode` in the bytecode buffer. The host encoder writes the POD directly
+// through a `memcpy`-style copy; the shader reads field-by-field at compile-time-known word offsets within the per-node
+// 12-word slot. Kept as a single named constant so a future change to the device-node layout (e.g. dropping `_pad0`)
+// only needs to touch one place. Mirrors the sizer's per-node walk in `adstack_sizer_shader.cpp` which uses the same
+// convention.
+constexpr uint32_t kNodeWords = sizeof(AdStackSizeExprDeviceNode) / 4u;
+static_assert(sizeof(AdStackSizeExprDeviceNode) % 4u == 0u,
+              "AdStackSizeExprDeviceNode must be a multiple of 4 bytes for direct u32[] indexing");
+
+// Field offsets within an `AdStackSizeExprDeviceNode` slot, in u32 words. Keep in sync with the struct definition in
+// `quadrants/ir/adstack_size_expr_device.h`. Read fields by `slot_base_word + kField...` to keep the shader IR
+// straight-line and easy to read against the host POD.
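+// For reference, the assumed 12-word slot layout implied by the offsets below (word 3 is the `_pad0` mentioned above;
+// word 9 is taken to be alignment padding before the 8-byte `const_value` - an inference, not read from the POD):
+//   [0] kind   [1] operand_a   [2] operand_b   [4] var_id   [5] prim_dt
+//   [6] arg_buffer_offset   [7] indices_offset   [8] indices_count   [10..11] const_value (i64)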
+constexpr uint32_t kNodeWordKind = 0; +constexpr uint32_t kNodeWordOperandA = 1; +constexpr uint32_t kNodeWordOperandB = 2; +// `kNodeWordVarId = 4` carries the dense device-scope slot for `kBoundVariable`; the host encoder remaps each captured +// chain bound-var id to slot `[0, num_axes)` outermost-first. Read at body interpretation time to index into the +// per-thread scope array. +constexpr uint32_t kNodeWordVarId = 4; +constexpr uint32_t kNodeWordPrimDt = 5; +constexpr uint32_t kNodeWordArgBufferOffset = 6; +constexpr uint32_t kNodeWordIndicesOffset = 7; +constexpr uint32_t kNodeWordIndicesCount = 8; +// `const_value` is a 64-bit field at words 10/11; we always read both halves through `load_buf_i64` which adds 1 to the +// lo word index internally, so we only name the lo offset here. +constexpr uint32_t kNodeWordConstValueLo = 10; + +// Small helper: read one uint32 word from a storage-buffer-backed uint32[] at the given scalar index. Mirrors the +// same-named helper in `adstack_bound_reducer_shader.cpp` and `adstack_sizer_shader.cpp`; kept local to this TU so the +// shader's symbol set stays self-contained and the helper inlines without cross-file linkage. +Value load_buf_u32(IRBuilder &ir, Value buffer, Value word_idx) { + Value ptr = ir.struct_array_access(ir.u32_type(), buffer, word_idx); + return ir.load_variable(ptr, ir.u32_type()); +} + +// Load one i32 word from a storage-buffer-backed uint32[] at scalar index `word_idx`, bitcast-reinterpreted as i32. The +// host encoder writes the PODs verbatim, so signed fields like `operand_a` round-trip through the u32 SSBO via +// little-endian bitcast. Returns an i32 SSA value. +Value load_buf_i32(IRBuilder &ir, Value buffer, Value word_idx) { + Value u = load_buf_u32(ir, buffer, word_idx); + return ir.make_value(spv::OpBitcast, ir.i32_type(), u); +} + +// Load one i64 from two adjacent u32 words at `lo_word_idx` and `lo_word_idx + 1`. Used for `kConst` nodes whose +// `const_value` field straddles two u32 slots in the bytecode buffer (12-word node POD layout). Returned as i64. +Value load_buf_i64(IRBuilder &ir, Value buffer, Value lo_word_idx) { + Value lo = load_buf_u32(ir, buffer, lo_word_idx); + Value hi = load_buf_u32(ir, buffer, ir.add(lo_word_idx, ir.uint_immediate_number(ir.u32_type(), 1u))); + Value lo64 = ir.cast(ir.u64_type(), lo); + Value hi64 = ir.cast(ir.u64_type(), hi); + Value shift = ir.uint_immediate_number(ir.u64_type(), 32u); + Value hi_shifted = ir.make_value(spv::OpShiftLeftLogical, ir.u64_type(), hi64, shift); + Value combined_u64 = ir.make_value(spv::OpBitwiseOr, ir.u64_type(), lo64, hi_shifted); + return ir.make_value(spv::OpBitcast, ir.i64_type(), combined_u64); +} + +// Assemble a u64 from two adjacent little-endian u32 words at `base_word_idx` and `base_word_idx + 1`. Used to +// reconstruct ndarray data pointers from the kernel arg buffer for `kExternalTensorRead`. 
+Value load_arg_buf_u64_ptr(IRBuilder &ir, Value buffer, Value base_word_idx) {
+  Value lo = load_buf_u32(ir, buffer, base_word_idx);
+  Value hi = load_buf_u32(ir, buffer, ir.add(base_word_idx, ir.uint_immediate_number(ir.u32_type(), 1u)));
+  Value lo64 = ir.cast(ir.u64_type(), lo);
+  Value hi64 = ir.cast(ir.u64_type(), hi);
+  Value shift = ir.uint_immediate_number(ir.u64_type(), 32u);
+  Value hi_shifted = ir.make_value(spv::OpShiftLeftLogical, ir.u64_type(), hi64, shift);
+  return ir.make_value(spv::OpBitwiseOr, ir.u64_type(), lo64, hi_shifted);
+}
+
+// Physical-Storage-Buffer load of one scalar of type `load_ty` at byte offset `byte_off_u64` from `base_u64`, with
+// `elem_alignment` forwarded as the load's alignment operand. Mirrors the wrapper-struct PSB load pattern in
+// `adstack_sizer_shader.cpp::psb_load_scalar` and `adstack_bound_reducer_shader.cpp::psb_load_u32_at_byte_off`.
+Value psb_load_scalar_at_byte_off(IRBuilder &ir,
+                                  Value base_u64,
+                                  Value byte_off_u64,
+                                  const SType &load_ty,
+                                  uint32_t elem_alignment) {
+  Value target_u64 = ir.add(base_u64, byte_off_u64);
+  SType ptr_elem_type = ir.get_pointer_type(load_ty, spv::StorageClassPhysicalStorageBuffer);
+  std::vector<std::tuple<SType, std::string, size_t>> members = {{load_ty, "_m0", 0}};
+  SType wrapper_struct = ir.create_struct_type(members);
+  SType ptr_struct_type = ir.get_pointer_type(wrapper_struct, spv::StorageClassPhysicalStorageBuffer);
+  Value struct_ptr = ir.make_value(spv::OpConvertUToPtr, ptr_struct_type, target_u64);
+  Value scalar_ptr = ir.make_value(spv::OpAccessChain, ptr_elem_type, struct_ptr, ir.const_i32_zero_);
+  Value scalar = ir.new_value(load_ty, ValueKind::kNormal);
+  ir.make_inst(spv::OpLoad, load_ty, scalar, scalar_ptr, spv::MemoryAccessAlignedMask, elem_alignment);
+  return scalar;
+}
+
+// Switch on `prim_dt` (a `PrimitiveTypeID` value) and emit the matching PSB load + sign/zero-extend to i64. The
+// recognizer only emits integer leaves (`recognize_adstack_max_reducer_specs` rejects float-typed bodies), so the float
+// arms are unreachable and not emitted. Element index is in elements (not bytes); the per-dtype load multiplies by the
+// element size internally.
+Value emit_psb_load_i64_int_only(IRBuilder &ir, Value data_ptr_u64, Value linear_i32, Value prim_dt_i32) {
+  Label merge = ir.new_label();
+  Label case_i8 = ir.new_label();
+  Label case_i16 = ir.new_label();
+  Label case_i32 = ir.new_label();
+  Label case_i64 = ir.new_label();
+  Label case_u8 = ir.new_label();
+  Label case_u16 = ir.new_label();
+  Label case_u32 = ir.new_label();
+  Label case_u64 = ir.new_label();
+  Label case_default = ir.new_label();
+
+  Value linear_u64 = ir.cast(ir.u64_type(), ir.make_value(spv::OpBitcast, ir.u32_type(), linear_i32));
+  Value result_var = ir.alloca_variable(ir.i64_type());
+  ir.store_variable(result_var, ir.int_immediate_number(ir.i64_type(), 0));
+
+  ir.make_inst(spv::OpSelectionMerge, merge, spv::SelectionControlMaskNone);
+  ir.make_inst(spv::OpSwitch, prim_dt_i32, case_default,               //
+               static_cast<uint32_t>(PrimitiveTypeID::i8), case_i8,    //
+               static_cast<uint32_t>(PrimitiveTypeID::i16), case_i16,  //
+               static_cast<uint32_t>(PrimitiveTypeID::i32), case_i32,  //
+               static_cast<uint32_t>(PrimitiveTypeID::i64), case_i64,  //
+               static_cast<uint32_t>(PrimitiveTypeID::u8), case_u8,    //
+               static_cast<uint32_t>(PrimitiveTypeID::u16), case_u16,  //
+               static_cast<uint32_t>(PrimitiveTypeID::u32), case_u32,  //
+               static_cast<uint32_t>(PrimitiveTypeID::u64), case_u64);
+
+  auto emit_int_case = [&](Label lbl, const SType &load_ty, uint32_t elem_size, bool is_signed) {
+    ir.start_label(lbl);
+    Value byte_off = ir.mul(linear_u64, ir.uint_immediate_number(ir.u64_type(), elem_size));
+    Value v = psb_load_scalar_at_byte_off(ir, data_ptr_u64, byte_off, load_ty, elem_size);
+    Value v_i64;
+    if (load_ty.id == ir.i64_type().id) {
+      v_i64 = v;
+    } else if (load_ty.id == ir.u64_type().id) {
+      v_i64 = ir.make_value(spv::OpBitcast, ir.i64_type(), v);
+    } else if (is_signed) {
+      v_i64 = ir.make_value(spv::OpSConvert, ir.i64_type(), v);
+    } else {
+      Value v_u64 = ir.make_value(spv::OpUConvert, ir.u64_type(), v);
+      v_i64 = ir.make_value(spv::OpBitcast, ir.i64_type(), v_u64);
+    }
+    ir.store_variable(result_var, v_i64);
+    ir.make_inst(spv::OpBranch, merge);
+  };
+
+  emit_int_case(case_i8, ir.i8_type(), 1u, true);
+  emit_int_case(case_i16, ir.i16_type(), 2u, true);
+  emit_int_case(case_i32, ir.i32_type(), 4u, true);
+  emit_int_case(case_i64, ir.i64_type(), 8u, true);
+  emit_int_case(case_u8, ir.u8_type(), 1u, false);
+  emit_int_case(case_u16, ir.u16_type(), 2u, false);
+  emit_int_case(case_u32, ir.u32_type(), 4u, false);
+  emit_int_case(case_u64, ir.u64_type(), 8u, false);
+
+  ir.start_label(case_default);
+  ir.make_inst(spv::OpBranch, merge);
+
+  ir.start_label(merge);
+  return ir.load_variable(result_var, ir.i64_type());
+}
+
+// Dynamic-index access into a Function-scope array (OpVariable with array type, allocated via `alloca_variable`).
+// `IRBuilder::struct_array_access` is a buffer-only helper (asserts `kStructArrayPtr`); for Function-scope arrays we
+// emit `OpAccessChain` directly with the per-element pointer type. Returns a pointer-to-element that can be passed to
+// `load_variable` / `store_variable`.
+Value alloca_array_access(IRBuilder &ir, Value arr_var, const SType &elem_type, Value index_i32) {
+  SType elem_ptr_type = ir.get_pointer_type(elem_type, spv::StorageClassFunction);
+  Value elem_ptr = ir.new_value(elem_ptr_type, ValueKind::kVariablePtr);
+  ir.make_inst(spv::OpAccessChain, elem_ptr_type, elem_ptr, arr_var, index_i32);
+  return elem_ptr;
+}
+
+// Compute the element index for a `kExternalTensorRead` node body leaf. The loop walks `indices[node.indices_offset ..
+// node.indices_offset + 2 * node.indices_count)` as `(idx_raw, elem_stride)` pairs and accumulates `v * stride` where
+// `v` is the resolved integer index. `scope_var` is the per-thread Function-scope
+// `i64[kAdStackSizeExprDeviceMaxBoundVars]` array the dispatch pre-populates per axis. A `BoundVariable` reference is
+// encoded as a negative `idx_raw` per `SerializedSizeExprNode::indices`; the host encoder dense-remaps every captured
+// chain bound-var id into a device-scope slot in `[0, num_axes)`, so `var_id = -(idx_raw + 1)` and we read
+// `scope[var_id]` at runtime. Multi-axis chain captures (Case 3) populate one scope slot per axis before the body
+// walk; single-axis specs collapse to one populated slot at index 0.
+Value compute_external_read_elem_index(IRBuilder &ir,
+                                       Value bytecode_buf,
+                                       Value indices_base_word,
+                                       Value indices_offset_i32,
+                                       Value indices_count_i32,
+                                       Value scope_var) {
+  Value acc_var = ir.alloca_variable(ir.i32_type());
+  ir.store_variable(acc_var, ir.int_immediate_number(ir.i32_type(), 0));
+  Value k_var = ir.alloca_variable(ir.i32_type());
+  ir.store_variable(k_var, ir.int_immediate_number(ir.i32_type(), 0));
+
+  Label head = ir.new_label();
+  Label body = ir.new_label();
+  Label cont = ir.new_label();
+  Label merge = ir.new_label();
+
+  ir.make_inst(spv::OpBranch, head);
+
+  ir.start_label(head);
+  Value k_now = ir.load_variable(k_var, ir.i32_type());
+  Value cond = ir.lt(k_now, indices_count_i32);
+  ir.make_inst(spv::OpLoopMerge, merge, cont, spv::LoopControlMaskNone);
+  ir.make_inst(spv::OpBranchConditional, cond, body, merge);
+
+  ir.start_label(body);
+  Value indices_off_u32 = ir.cast(ir.u32_type(), indices_offset_i32);
+  Value k_u32 = ir.cast(ir.u32_type(), k_now);
+  Value pair_base_u32 = ir.add(indices_off_u32, ir.mul(k_u32, ir.uint_immediate_number(ir.u32_type(), 2u)));
+  Value idx_word_u32 = ir.add(indices_base_word, pair_base_u32);
+  Value stride_word_u32 = ir.add(idx_word_u32, ir.uint_immediate_number(ir.u32_type(), 1u));
+  Value idx_raw_i32 = load_buf_i32(ir, bytecode_buf, idx_word_u32);
+  Value stride_i32 = load_buf_i32(ir, bytecode_buf, stride_word_u32);
+
+  // Resolve the raw index. `idx_raw >= 0` means a constant axis index baked at encode time; `idx_raw < 0` means a
+  // bound-variable reference encoded as `-(slot + 1)` where `slot` is the dense device-scope index the host encoder
+  // assigned to that chain axis. We narrow `scope[slot]` (i64) to i32 to match the per-thread sizer's index width; the
+  // recognizer rejects specs whose closed-form bound exceeds 2^31 (host launcher caps the cross-product length at u32
+  // max), so the i32 narrowing is safe.
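+  // Illustration (hypothetical leaf): `arr[i, 3]` with the enclosing axis `i` in scope slot 0 encodes its pairs as
+  // [(-1, stride_axis0), (3, stride_axis1)]; the first pair resolves through scope[-(-1 + 1)] = scope[0], the second
+  // is the constant 3 baked at encode time.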
+  Label const_lbl = ir.new_label();
+  Label var_lbl = ir.new_label();
+  Label sel_merge = ir.new_label();
+  Value is_const = ir.ge(idx_raw_i32, ir.int_immediate_number(ir.i32_type(), 0));
+  ir.make_inst(spv::OpSelectionMerge, sel_merge, spv::SelectionControlMaskNone);
+  ir.make_inst(spv::OpBranchConditional, is_const, const_lbl, var_lbl);
+
+  ir.start_label(const_lbl);
+  Value v_const_i32 = idx_raw_i32;
+  Label const_end = ir.current_label();
+  ir.make_inst(spv::OpBranch, sel_merge);
+
+  ir.start_label(var_lbl);
+  Value var_slot_i32 = ir.sub(ir.int_immediate_number(ir.i32_type(), -1), idx_raw_i32);
+  Value scope_ptr = alloca_array_access(ir, scope_var, ir.i64_type(), var_slot_i32);
+  Value scope_val_i64 = ir.load_variable(scope_ptr, ir.i64_type());
+  Value v_var_i32 = ir.cast(ir.i32_type(), scope_val_i64);
+  Label var_end = ir.current_label();
+  ir.make_inst(spv::OpBranch, sel_merge);
+
+  ir.start_label(sel_merge);
+  PhiValue v = ir.make_phi(ir.i32_type(), 2);
+  v.set_incoming(0, v_const_i32, const_end);
+  v.set_incoming(1, v_var_i32, var_end);
+
+  Value contribution = ir.mul(Value(v), stride_i32);
+  Value acc_now = ir.load_variable(acc_var, ir.i32_type());
+  ir.store_variable(acc_var, ir.add(acc_now, contribution));
+  ir.make_inst(spv::OpBranch, cont);
+
+  ir.start_label(cont);
+  Value k_next = ir.add(k_now, ir.int_immediate_number(ir.i32_type(), 1));
+  ir.store_variable(k_var, k_next);
+  ir.make_inst(spv::OpBranch, head);
+
+  ir.start_label(merge);
+  return ir.load_variable(acc_var, ir.i32_type());
+}
+
+// Per-thread post-order interpreter for the body subtree. Iterates ascending node indices `0..body_node_count`, reading
+// each node's POD from the bytecode buffer at `body_bytecode_offset_words + i * kNodeWords` and storing the computed
+// i64 value into `vals[i]`. The `vals[]` array is a Function-scope OpVariable of
+// `i64[kAdStackMaxReducerMaxBodyNodes]`; per-thread footprint is 8 bytes/node-slot. Final result is
+// `vals[body_node_count - 1]` (the root, since post-order encoding places the root last).
+Value interpret_body(IRBuilder &ir,
+                     Value args_buf,
+                     Value bytecode_buf,
+                     Value body_bytecode_offset_words,
+                     Value body_indices_offset_words,
+                     Value body_node_count_i32,
+                     Value scope_var) {
+  // Function-scope per-thread value storage. Allocate as a fixed-size i64 array; the index used at store time is the
+  // current node index.
+  SType i64_arr_ty = ir.get_array_type(ir.i64_type(), kAdStackMaxReducerMaxBodyNodes);
+  Value vals_var = ir.alloca_variable(i64_arr_ty);
+
+  Value i_var = ir.alloca_variable(ir.i32_type());
+  ir.store_variable(i_var, ir.int_immediate_number(ir.i32_type(), 0));
+
+  Label head = ir.new_label();
+  Label body_lbl = ir.new_label();
+  Label cont = ir.new_label();
+  Label merge = ir.new_label();
+
+  ir.make_inst(spv::OpBranch, head);
+
+  ir.start_label(head);
+  Value i_now = ir.load_variable(i_var, ir.i32_type());
+  Value loop_cond = ir.lt(i_now, body_node_count_i32);
+  ir.make_inst(spv::OpLoopMerge, merge, cont, spv::LoopControlMaskNone);
+  ir.make_inst(spv::OpBranchConditional, loop_cond, body_lbl, merge);
+
+  ir.start_label(body_lbl);
+  // Compute this node's slot base word: `body_bytecode_offset_words + i * kNodeWords`.
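+  // (e.g. with kNodeWords = 12, node 2 of a body whose bytecode starts at word 100 sits at word 124.)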
+  Value i_u32 = ir.cast(ir.u32_type(), i_now);
+  Value slot_base =
+      ir.add(body_bytecode_offset_words, ir.mul(i_u32, ir.uint_immediate_number(ir.u32_type(), kNodeWords)));
+  Value kind_i32 =
+      load_buf_i32(ir, bytecode_buf, ir.add(slot_base, ir.uint_immediate_number(ir.u32_type(), kNodeWordKind)));
+
+  Label case_const = ir.new_label();
+  Label case_bv = ir.new_label();
+  Label case_etr = ir.new_label();
+  Label case_fl = ir.new_label();
+  Label case_add = ir.new_label();
+  Label case_sub = ir.new_label();
+  Label case_mul = ir.new_label();
+  Label case_max = ir.new_label();
+  Label case_default = ir.new_label();
+  Label kind_merge = ir.new_label();
+
+  Value computed_var = ir.alloca_variable(ir.i64_type());
+  ir.store_variable(computed_var, ir.int_immediate_number(ir.i64_type(), 0));
+
+  ir.make_inst(spv::OpSelectionMerge, kind_merge, spv::SelectionControlMaskNone);
+  ir.make_inst(spv::OpSwitch, kind_i32, case_default,                                            //
+               static_cast<uint32_t>(AdStackSizeExprDeviceKind::kConst), case_const,             //
+               static_cast<uint32_t>(AdStackSizeExprDeviceKind::kBoundVariable), case_bv,        //
+               static_cast<uint32_t>(AdStackSizeExprDeviceKind::kExternalTensorRead), case_etr,  //
+               static_cast<uint32_t>(AdStackSizeExprDeviceKind::kFieldLoad), case_fl,            //
+               static_cast<uint32_t>(AdStackSizeExprDeviceKind::kAdd), case_add,                 //
+               static_cast<uint32_t>(AdStackSizeExprDeviceKind::kSub), case_sub,                 //
+               static_cast<uint32_t>(AdStackSizeExprDeviceKind::kMul), case_mul,                 //
+               static_cast<uint32_t>(AdStackSizeExprDeviceKind::kMax), case_max);
+
+  // kConst: load the i64 const_value from words [10, 11] of this slot.
+  ir.start_label(case_const);
+  {
+    Value const_val = load_buf_i64(ir, bytecode_buf,
+                                   ir.add(slot_base, ir.uint_immediate_number(ir.u32_type(), kNodeWordConstValueLo)));
+    ir.store_variable(computed_var, const_val);
+    ir.make_inst(spv::OpBranch, kind_merge);
+  }
+
+  // kBoundVariable: read the per-thread scope slot the dispatch pre-populated for this axis. The host encoder
+  // dense-remaps every captured chain bound-var id to a slot in `[0, num_axes)`; the dispatch site stores
+  // `per_axis_begin[a] + axis_idx_a` into `scope[a]` before walking the body.
+  ir.start_label(case_bv);
+  {
+    Value var_id_i32 =
+        load_buf_i32(ir, bytecode_buf, ir.add(slot_base, ir.uint_immediate_number(ir.u32_type(), kNodeWordVarId)));
+    Value scope_ptr_bv = alloca_array_access(ir, scope_var, ir.i64_type(), var_id_i32);
+    Value scope_val_bv = ir.load_variable(scope_ptr_bv, ir.i64_type());
+    ir.store_variable(computed_var, scope_val_bv);
+    ir.make_inst(spv::OpBranch, kind_merge);
+  }
+
+  // kExternalTensorRead: PSB-load the body ndarray's element at the computed linear index.
+  ir.start_label(case_etr);
+  {
+    Value arg_byte_offset = load_buf_i32(
+        ir, bytecode_buf, ir.add(slot_base, ir.uint_immediate_number(ir.u32_type(), kNodeWordArgBufferOffset)));
+    Value prim_dt =
+        load_buf_i32(ir, bytecode_buf, ir.add(slot_base, ir.uint_immediate_number(ir.u32_type(), kNodeWordPrimDt)));
+    Value indices_offset = load_buf_i32(
+        ir, bytecode_buf, ir.add(slot_base, ir.uint_immediate_number(ir.u32_type(), kNodeWordIndicesOffset)));
+    Value indices_count = load_buf_i32(
+        ir, bytecode_buf, ir.add(slot_base, ir.uint_immediate_number(ir.u32_type(), kNodeWordIndicesCount)));
+    Value linear_i32 = compute_external_read_elem_index(ir, bytecode_buf, body_indices_offset_words, indices_offset,
+                                                        indices_count, scope_var);
+    // The host encoder writes `arg_buffer_offset` in bytes (it's the byte offset of the ndarray's `data_ptr` slot
+    // within the kernel arg buffer). The shader reads `args_buf` as u32[], so divide by 4 to land at the right u32
+    // word index. Mirrors `adstack_sizer_shader.cpp`'s same conversion.
+    Value arg_word_i32 = ir.make_value(spv::OpShiftRightArithmetic, ir.i32_type(), arg_byte_offset,
+                                       ir.uint_immediate_number(ir.u32_type(), 2u));
+    Value arg_word_u32 = ir.cast(ir.u32_type(), arg_word_i32);
+    Value data_ptr_u64 = load_arg_buf_u64_ptr(ir, args_buf, arg_word_u32);
+    Value loaded_i64 = emit_psb_load_i64_int_only(ir, data_ptr_u64, linear_i32, prim_dt);
+    ir.store_variable(computed_var, loaded_i64);
+    ir.make_inst(spv::OpBranch, kind_merge);
+  }
+
+  // kFieldLoad: PSB-load the body field's element at the indices-table-resolved linear offset. Same `[idx_a_raw,
+  // elem_stride_a]` indices layout as `kExternalTensorRead`, so `compute_external_read_elem_index` is reused verbatim
+  // for the bound-var dense-remap. The base pointer comes from the encoder's pre-baked `const_value` (= snode tree
+  // root_psb + place_byte_offset_in_root); the kernel arg buffer is irrelevant since FieldLoads target SNodes, not
+  // ndarrays. `track_physical_buffer` on the SNode tree root buffer is the dispatch-site's responsibility (mirrors the
+  // ndarray residency hint that kExternalTensorRead requires on Apple Silicon).
+  ir.start_label(case_fl);
+  {
+    Value prim_dt =
+        load_buf_i32(ir, bytecode_buf, ir.add(slot_base, ir.uint_immediate_number(ir.u32_type(), kNodeWordPrimDt)));
+    Value const_lo_idx = ir.add(slot_base, ir.uint_immediate_number(ir.u32_type(), kNodeWordConstValueLo));
+    Value base_i64 = load_buf_i64(ir, bytecode_buf, const_lo_idx);
+    Value base_u64 = ir.make_value(spv::OpBitcast, ir.u64_type(), base_i64);
+    Value indices_offset = load_buf_i32(
+        ir, bytecode_buf, ir.add(slot_base, ir.uint_immediate_number(ir.u32_type(), kNodeWordIndicesOffset)));
+    Value indices_count = load_buf_i32(
+        ir, bytecode_buf, ir.add(slot_base, ir.uint_immediate_number(ir.u32_type(), kNodeWordIndicesCount)));
+    Value linear_i32 = compute_external_read_elem_index(ir, bytecode_buf, body_indices_offset_words, indices_offset,
+                                                        indices_count, scope_var);
+    Value loaded_i64 = emit_psb_load_i64_int_only(ir, base_u64, linear_i32, prim_dt);
+    ir.store_variable(computed_var, loaded_i64);
+    ir.make_inst(spv::OpBranch, kind_merge);
+  }
+
+  // Helper to load `vals[op_idx]` for the binary arithmetic cases. The op-index is read from the slot's `operand_a` /
+  // `operand_b` fields, both of which the post-order encoding guarantees are < `i_now` so the value is already
+  // computed.
+  auto load_val_at = [&](uint32_t word_off) -> Value {
+    Value op_idx_i32 =
+        load_buf_i32(ir, bytecode_buf, ir.add(slot_base, ir.uint_immediate_number(ir.u32_type(), word_off)));
+    Value ptr = alloca_array_access(ir, vals_var, ir.i64_type(), op_idx_i32);
+    return ir.load_variable(ptr, ir.i64_type());
+  };
+
+  ir.start_label(case_add);
+  {
+    Value a = load_val_at(kNodeWordOperandA);
+    Value b = load_val_at(kNodeWordOperandB);
+    ir.store_variable(computed_var, ir.add(a, b));
+    ir.make_inst(spv::OpBranch, kind_merge);
+  }
+
+  ir.start_label(case_sub);
+  {
+    Value a = load_val_at(kNodeWordOperandA);
+    Value b = load_val_at(kNodeWordOperandB);
+    Value diff = ir.sub(a, b);
+    // Match the host evaluator's saturating-sub behaviour for `SizeExpr::Kind::Sub` (clamps to 0). Per-thread sizes
+    // are non-negative by construction; signed subtraction would let a negative value poison the running max if a
+    // body subtree's `arr_a[i] < arr_b[i]` for some thread.
+    Value zero_i64 = ir.int_immediate_number(ir.i64_type(), 0);
+    Value is_neg = ir.lt(diff, zero_i64);
+    Value clamped = ir.make_value(spv::OpSelect, ir.i64_type(), is_neg, zero_i64, diff);
+    ir.store_variable(computed_var, clamped);
+    ir.make_inst(spv::OpBranch, kind_merge);
+  }
+
+  ir.start_label(case_mul);
+  {
+    Value a = load_val_at(kNodeWordOperandA);
+    Value b = load_val_at(kNodeWordOperandB);
+    ir.store_variable(computed_var, ir.mul(a, b));
+    ir.make_inst(spv::OpBranch, kind_merge);
+  }
+
+  ir.start_label(case_max);
+  {
+    Value a = load_val_at(kNodeWordOperandA);
+    Value b = load_val_at(kNodeWordOperandB);
+    Value gt = ir.make_value(spv::OpSGreaterThan, ir.bool_type(), a, b);
+    Value m = ir.make_value(spv::OpSelect, ir.i64_type(), gt, a, b);
+    ir.store_variable(computed_var, m);
+    ir.make_inst(spv::OpBranch, kind_merge);
+  }
+
+  ir.start_label(case_default);
+  ir.make_inst(spv::OpBranch, kind_merge);
+
+  ir.start_label(kind_merge);
+  Value computed = ir.load_variable(computed_var, ir.i64_type());
+  // Store into vals[i].
+  Value vals_slot_ptr = alloca_array_access(ir, vals_var, ir.i64_type(), i_now);
+  ir.store_variable(vals_slot_ptr, computed);
+  ir.make_inst(spv::OpBranch, cont);
+
+  ir.start_label(cont);
+  Value i_next = ir.add(i_now, ir.int_immediate_number(ir.i32_type(), 1));
+  ir.store_variable(i_var, i_next);
+  ir.make_inst(spv::OpBranch, head);
+
+  ir.start_label(merge);
+  // Root index = body_node_count - 1.
+  Value root_idx = ir.sub(body_node_count_i32, ir.int_immediate_number(ir.i32_type(), 1));
+  Value root_ptr = alloca_array_access(ir, vals_var, ir.i64_type(), root_idx);
+  return ir.load_variable(root_ptr, ir.i64_type());
+}
+
+}  // namespace
+
+std::vector<uint32_t> build_adstack_max_reducer_spirv(Arch arch, const DeviceCapabilityConfig *caps) {
+  if (!caps->get(DeviceCapability::spirv_has_physical_storage_buffer)) {
+    return {};
+  }
+  if (!caps->get(DeviceCapability::spirv_has_int64)) {
+    return {};
+  }
+
+  IRBuilder ir(arch, caps);
+  ir.init_header();
+
+  // Storage-buffer bindings (set 0). The output buffer holds two u32 slots per spec: even index = u32 atomic-max
+  // running max, odd index = u32 atomic-or overflow flag (0 = max fits in u32, non-zero = at least one thread observed
+  // a body i64 value above `UINT32_MAX` and the host should fall back to host eval). u32 atomics are universally
+  // translated through spirv-cross's MSL backend whereas i64 atomic-max is not, which is what unlocks the Metal /
+  // Vulkan-via-MoltenVK paths.
+  Value args_buf = ir.buffer_argument(ir.u32_type(), 0, 0, "adstack_max_reducer_args");
+  Value output_buf = ir.buffer_argument(ir.u32_type(), 0, 1, "adstack_max_reducer_output");
+  Value params_buf = ir.buffer_argument(ir.u32_type(), 0, 2, "adstack_max_reducer_params");
+  Value bytecode_buf = ir.buffer_argument(ir.u32_type(), 0, 3, "adstack_max_reducer_bytecode");
+
+  Value main_func = ir.new_function();
+  ir.start_function(main_func);
+  ir.set_work_group_size({static_cast<uint32_t>(kAdStackMaxReducerWorkgroupSize), 1, 1});
+
+  Value gid_u32 = ir.get_global_invocation_id(0);
+
+  // Load params at the top of `main`. spirv-opt CSEs the redundant loads if any, but the explicit hoist makes the
+  // shader's data flow easier to read against the host POD. Per-axis arrays are loaded inside the per-thread loop
+  // since they are indexed by axis k; the loop pulls one axis word per iteration.
+  Value output_slot = load_buf_u32(
+      ir, params_buf, ir.uint_immediate_number(ir.u32_type(), AdStackMaxReducerParams::kWordOffsetOutputSlot));
+  Value length =
+      load_buf_u32(ir, params_buf, ir.uint_immediate_number(ir.u32_type(), AdStackMaxReducerParams::kWordOffsetLength));
+  Value num_axes_u32 = load_buf_u32(
+      ir, params_buf, ir.uint_immediate_number(ir.u32_type(), AdStackMaxReducerParams::kWordOffsetNumAxes));
+  Value body_bytecode_offset_words = load_buf_u32(
+      ir, params_buf,
+      ir.uint_immediate_number(ir.u32_type(), AdStackMaxReducerParams::kWordOffsetBodyBytecodeOffsetWords));
+  Value body_node_count_u32 = load_buf_u32(
+      ir, params_buf, ir.uint_immediate_number(ir.u32_type(), AdStackMaxReducerParams::kWordOffsetBodyNodeCount));
+  Value body_indices_offset_words =
+      load_buf_u32(ir, params_buf,
+                   ir.uint_immediate_number(ir.u32_type(), AdStackMaxReducerParams::kWordOffsetBodyIndicesOffsetWords));
+
+  // Per-thread strided iteration: each thread walks `kElementsPerThread` cross-product cells at stride `total_threads`
+  // before atomic-maxing the per-thread running max into the output slot. The host launcher caps `num_workgroups_x` at
+  // the Vulkan / Metal `maxComputeWorkGroupCount[0]` minimum of 65535; striding by `total_threads` lets the dispatch
+  // cover spec lengths up to `kElementsPerThread * 128 * 65535 = ~536M` cross-product cells without dropping any.
+  constexpr uint32_t kElementsPerThread = 64u;
+
+  Value body_node_count_i32 = ir.make_value(spv::OpBitcast, ir.i32_type(), body_node_count_u32);
+  Value num_axes_i32 = ir.make_value(spv::OpBitcast, ir.i32_type(), num_axes_u32);
+
+  // Per-thread scope array: `i64[kAdStackSizeExprDeviceMaxBoundVars]`. The dispatch pre-populates one slot per
+  // captured chain axis before each cross-product cell; the body interpreter's `kBoundVariable` and
+  // `kExternalTensorRead` cases read from this array at the dense-remapped slot id.
+  SType scope_arr_ty = ir.get_array_type(ir.i64_type(), kAdStackSizeExprDeviceMaxBoundVars);
+  Value scope_var = ir.alloca_variable(scope_arr_ty);
+  // Zero-init every scope slot once. The cross-product loop only writes the `[0, num_axes)` prefix, so an out-of-range
+  // slot read (which the body grammar disallows but the device interpreter still permits) lands on a deterministic
+  // zero rather than uninitialised memory.
+  {
+    Value init_k_var = ir.alloca_variable(ir.i32_type());
+    ir.store_variable(init_k_var, ir.int_immediate_number(ir.i32_type(), 0));
+    Label init_head = ir.new_label();
+    Label init_body = ir.new_label();
+    Label init_cont = ir.new_label();
+    Label init_merge = ir.new_label();
+    ir.make_inst(spv::OpBranch, init_head);
+    ir.start_label(init_head);
+    Value init_k_now = ir.load_variable(init_k_var, ir.i32_type());
+    Value init_cond = ir.lt(
+        init_k_now, ir.int_immediate_number(ir.i32_type(), static_cast<int32_t>(kAdStackSizeExprDeviceMaxBoundVars)));
+    ir.make_inst(spv::OpLoopMerge, init_merge, init_cont, spv::LoopControlMaskNone);
+    ir.make_inst(spv::OpBranchConditional, init_cond, init_body, init_merge);
+    ir.start_label(init_body);
+    Value init_slot_ptr = alloca_array_access(ir, scope_var, ir.i64_type(), init_k_now);
+    ir.store_variable(init_slot_ptr, ir.int_immediate_number(ir.i64_type(), 0));
+    ir.make_inst(spv::OpBranch, init_cont);
+    ir.start_label(init_cont);
+    ir.store_variable(init_k_var, ir.add(init_k_now, ir.int_immediate_number(ir.i32_type(), 1)));
+    ir.make_inst(spv::OpBranch, init_head);
+    ir.start_label(init_merge);
+  }
+
+  // Per-thread running max + overflow flag, materialised in Function-scope variables and folded into the global output
+  // buffer at the end of the strided loop with a single `OpAtomicUMax` / `OpAtomicOr`.
+  Value local_max_var = ir.alloca_variable(ir.u32_type());
+  ir.store_variable(local_max_var, ir.uint_immediate_number(ir.u32_type(), 0u));
+  Value local_overflow_var = ir.alloca_variable(ir.u32_type());
+  ir.store_variable(local_overflow_var, ir.uint_immediate_number(ir.u32_type(), 0u));
+
+  // Strided loop: for k in [0, kElementsPerThread), idx = (gid + k * total_threads). `total_threads = num_workgroups_x
+  // * workgroup_size_x` matches what the host computed when sizing the dispatch.
+  Value workgroup_size_v = ir.uint_immediate_number(ir.u32_type(), kAdStackMaxReducerWorkgroupSize);
+  Value num_wg_x = ir.get_num_work_groups(0);
+  Value total_threads = ir.mul(num_wg_x, workgroup_size_v);
+
+  Value k_var = ir.alloca_variable(ir.u32_type());
+  ir.store_variable(k_var, ir.uint_immediate_number(ir.u32_type(), 0u));
+  Label k_head = ir.new_label();
+  Label k_body = ir.new_label();
+  Label k_cont = ir.new_label();
+  Label k_merge = ir.new_label();
+  ir.make_inst(spv::OpBranch, k_head);
+
+  ir.start_label(k_head);
+  Value k_now = ir.load_variable(k_var, ir.u32_type());
+  Value k_in_range = ir.lt(k_now, ir.uint_immediate_number(ir.u32_type(), kElementsPerThread));
+  ir.make_inst(spv::OpLoopMerge, k_merge, k_cont, spv::LoopControlMaskNone);
+  ir.make_inst(spv::OpBranchConditional, k_in_range, k_body, k_merge);
+
+  ir.start_label(k_body);
+  {
+    Value stride_off = ir.mul(k_now, total_threads);
+    Value idx_u32 = ir.add(gid_u32, stride_off);
+    Label do_block = ir.new_label();
+    Label skip_block = ir.new_label();
+    Label idx_merge = ir.new_label();
+    Value idx_in_range = ir.lt(idx_u32, length);
+    ir.make_inst(spv::OpSelectionMerge, idx_merge, spv::SelectionControlMaskNone);
+    ir.make_inst(spv::OpBranchConditional, idx_in_range, do_block, skip_block);
+
+    ir.start_label(do_block);
+    {
+      // Decompose the linear cross-product index `idx_u32` into per-axis indices via mod / div outermost-first (axis
+      // `num_axes - 1` = innermost = fastest-varying). Per axis, populate `scope[per_axis_var_id[a]] =
+      // per_axis_begin[a] + axis_idx_a` so the body interpreter can read the bound value via the dense-remapped slot
+      // id encoded in the body bytecode.
+      // The recognizer grammar caps every axis's `begin + length` at i32 max, so the per-axis index decomposition
+      // stays in u32 / i64 without overflow concerns.
+      Value rem_var = ir.alloca_variable(ir.u32_type());
+      ir.store_variable(rem_var, idx_u32);
+      Value axis_iter_var = ir.alloca_variable(ir.i32_type());
+      ir.store_variable(axis_iter_var, num_axes_i32);
+      Label axis_head = ir.new_label();
+      Label axis_body = ir.new_label();
+      Label axis_cont = ir.new_label();
+      Label axis_merge = ir.new_label();
+      ir.make_inst(spv::OpBranch, axis_head);
+      ir.start_label(axis_head);
+      Value axis_now = ir.load_variable(axis_iter_var, ir.i32_type());
+      Value axis_cond = ir.gt(axis_now, ir.int_immediate_number(ir.i32_type(), 0));
+      ir.make_inst(spv::OpLoopMerge, axis_merge, axis_cont, spv::LoopControlMaskNone);
+      ir.make_inst(spv::OpBranchConditional, axis_cond, axis_body, axis_merge);
+      ir.start_label(axis_body);
+      {
+        Value axis_idx_i32 = ir.sub(axis_now, ir.int_immediate_number(ir.i32_type(), 1));
+        Value axis_idx_u32 = ir.cast(ir.u32_type(), axis_idx_i32);
+        Value len_word_idx = ir.add(
+            ir.uint_immediate_number(ir.u32_type(), AdStackMaxReducerParams::kWordOffsetPerAxisLength), axis_idx_u32);
+        Value lo_word_idx = ir.add(
+            ir.uint_immediate_number(ir.u32_type(), AdStackMaxReducerParams::kWordOffsetPerAxisBeginLo), axis_idx_u32);
+        Value hi_word_idx = ir.add(
+            ir.uint_immediate_number(ir.u32_type(), AdStackMaxReducerParams::kWordOffsetPerAxisBeginHi), axis_idx_u32);
+        Value var_word_idx = ir.add(
+            ir.uint_immediate_number(ir.u32_type(), AdStackMaxReducerParams::kWordOffsetPerAxisVarId), axis_idx_u32);
+        Value len_a = load_buf_u32(ir, params_buf, len_word_idx);
+        Value begin_lo_a = load_buf_u32(ir, params_buf, lo_word_idx);
+        Value begin_hi_a = load_buf_u32(ir, params_buf, hi_word_idx);
+        Value var_id_a = load_buf_i32(ir, params_buf, var_word_idx);
+        Value rem_now = ir.load_variable(rem_var, ir.u32_type());
+        Value idx_a_u32 = ir.mod(rem_now, len_a);
+        Value rem_next = ir.div(rem_now, len_a);
+        ir.store_variable(rem_var, rem_next);
+        // Reassemble the per-axis begin as i64 and add the axis index. The host encoder caps each axis's `begin +
+        // length` at i32 max, so the i64 sum cannot overflow.
+        Value lo64_a = ir.cast(ir.u64_type(), begin_lo_a);
+        Value hi64_a = ir.cast(ir.u64_type(), begin_hi_a);
+        Value shift_a = ir.uint_immediate_number(ir.u64_type(), 32u);
+        Value hi_shifted_a = ir.make_value(spv::OpShiftLeftLogical, ir.u64_type(), hi64_a, shift_a);
+        Value begin_u64_a = ir.make_value(spv::OpBitwiseOr, ir.u64_type(), lo64_a, hi_shifted_a);
+        Value begin_i64_a = ir.make_value(spv::OpBitcast, ir.i64_type(), begin_u64_a);
+        Value idx_a_i64 = ir.cast(ir.i64_type(), idx_a_u32);
+        Value scope_val = ir.add(begin_i64_a, idx_a_i64);
+        Value scope_slot_ptr = alloca_array_access(ir, scope_var, ir.i64_type(), var_id_a);
+        ir.store_variable(scope_slot_ptr, scope_val);
+        ir.store_variable(axis_iter_var, axis_idx_i32);
+        ir.make_inst(spv::OpBranch, axis_cont);
+      }
+      ir.start_label(axis_cont);
+      ir.make_inst(spv::OpBranch, axis_head);
+      ir.start_label(axis_merge);
+
+      Value result_i64 = interpret_body(ir, args_buf, bytecode_buf, body_bytecode_offset_words,
+                                        body_indices_offset_words, body_node_count_i32, scope_var);
+
+      // Clamp negative body values to 0. The `kSub` arm already clamps to 0 (matching the host evaluator), but the
+      // recognized grammar's leaves are integer ndarray / field reads whose stored values may themselves be negative,
+      // so the root value can still dip below zero.
+      // Once non-negative, branch on whether the value fits in `UINT32_MAX`: if it does, fold it into `local_max`; if
+      // it does not, set `local_overflow` (the host falls back to direct host-eval for that spec).
+      Value zero_i64 = ir.int_immediate_number(ir.i64_type(), 0);
+      Value u32_max_i64 = ir.int_immediate_number(ir.i64_type(), static_cast<int64_t>(0xFFFFFFFFll));
+      Value is_neg = ir.make_value(spv::OpSLessThan, ir.bool_type(), result_i64, zero_i64);
+      Value clamped_pos = ir.make_value(spv::OpSelect, ir.i64_type(), is_neg, zero_i64, result_i64);
+      Value overflow_cond = ir.make_value(spv::OpSGreaterThan, ir.bool_type(), clamped_pos, u32_max_i64);
+
+      Label of_then = ir.new_label();
+      Label of_else = ir.new_label();
+      Label of_merge = ir.new_label();
+      ir.make_inst(spv::OpSelectionMerge, of_merge, spv::SelectionControlMaskNone);
+      ir.make_inst(spv::OpBranchConditional, overflow_cond, of_then, of_else);
+
+      ir.start_label(of_then);
+      {
+        ir.store_variable(local_overflow_var, ir.uint_immediate_number(ir.u32_type(), 1u));
+        ir.make_inst(spv::OpBranch, of_merge);
+      }
+
+      ir.start_label(of_else);
+      {
+        Value value_u32 = ir.cast(ir.u32_type(), clamped_pos);
+        Value local_max_now = ir.load_variable(local_max_var, ir.u32_type());
+        Value bigger = ir.make_value(spv::OpUGreaterThan, ir.bool_type(), value_u32, local_max_now);
+        Value new_local_max = ir.make_value(spv::OpSelect, ir.u32_type(), bigger, value_u32, local_max_now);
+        ir.store_variable(local_max_var, new_local_max);
+        ir.make_inst(spv::OpBranch, of_merge);
+      }
+
+      ir.start_label(of_merge);
+      ir.make_inst(spv::OpBranch, idx_merge);
+    }
+
+    ir.start_label(skip_block);
+    ir.make_inst(spv::OpBranch, idx_merge);
+
+    ir.start_label(idx_merge);
+    ir.make_inst(spv::OpBranch, k_cont);
+  }
+
+  ir.start_label(k_cont);
+  Value k_next = ir.add(k_now, ir.uint_immediate_number(ir.u32_type(), 1u));
+  ir.store_variable(k_var, k_next);
+  ir.make_inst(spv::OpBranch, k_head);
+
+  ir.start_label(k_merge);
+  {
+    // Two u32 slots per spec: `output_buf[2 * output_slot]` = u32 atomic-max running value, `output_buf[2 *
+    // output_slot + 1]` = u32 atomic-or overflow flag. The host clears both slots to 0 before every launch (including
+    // re-launches, via the launcher's pre-dispatch clear), so the atomics always fold into a known baseline; zero is
+    // the identity for both `OpAtomicUMax` and `OpAtomicOr`, so threads that saw no in-range cells fold in harmlessly.
+    Value slot_idx_value = ir.mul(output_slot, ir.uint_immediate_number(ir.u32_type(), 2u));
+    Value slot_idx_overflow = ir.add(slot_idx_value, ir.uint_immediate_number(ir.u32_type(), 1u));
+    Value value_ptr = ir.struct_array_access(ir.u32_type(), output_buf, slot_idx_value);
+    Value overflow_ptr = ir.struct_array_access(ir.u32_type(), output_buf, slot_idx_overflow);
+
+    Value local_max_final = ir.load_variable(local_max_var, ir.u32_type());
+    ir.make_value(spv::OpAtomicUMax, ir.u32_type(), value_ptr, /*scope=*/ir.const_i32_one_,
+                  /*semantics=*/ir.const_i32_zero_, local_max_final);
+    Value local_overflow_final = ir.load_variable(local_overflow_var, ir.u32_type());
+    ir.make_value(spv::OpAtomicOr, ir.u32_type(), overflow_ptr, /*scope=*/ir.const_i32_one_,
+                  /*semantics=*/ir.const_i32_zero_, local_overflow_final);
+  }
+  ir.make_inst(spv::OpReturn);
+  ir.make_inst(spv::OpFunctionEnd);
+
+  std::vector<Value> entry_args = {args_buf, output_buf, params_buf, bytecode_buf};
+  ir.commit_kernel_function(main_func, "main", entry_args,
+                            {static_cast<uint32_t>(kAdStackMaxReducerWorkgroupSize), 1, 1});
+
+  return ir.finalize();
+}
+
+}  // namespace quadrants::lang::spirv
diff --git a/quadrants/codegen/spirv/adstack_max_reducer_shader.h b/quadrants/codegen/spirv/adstack_max_reducer_shader.h
new file mode 100644
index 0000000000..2ac91ab639
--- /dev/null
+++ b/quadrants/codegen/spirv/adstack_max_reducer_shader.h
@@ -0,0 +1,125 @@
+#pragma once
+
+#include <cstdint>
+#include <vector>
+
+#include "quadrants/ir/static_adstack_max_reducer_device.h"
+#include "quadrants/rhi/arch.h"
+#include "quadrants/rhi/public_device.h"
+
+namespace quadrants::lang::spirv {
+
+// Builds the SPIR-V compute shader that evaluates a captured `StaticAdStackMaxReducerSpec`'s body subtree over a thread
+// range and atomic-maxes the result into a per-spec slot of `BufferType::AdStackMaxReducerOutput`. Dispatched once per
+// captured `MaxOverRange` node before the main task on the max-reducer path; the resulting per-spec value is
+// substituted as a `Const` into the per-stack `SerializedSizeExpr` tree by `substitute_precomputed_max_over_range`
+// before any of the three eval paths (host fast path, SPIR-V on-device sizer, LLVM device sizer) walks it.
+//
+// The shader is generic (parametrised at dispatch time by the parameter blob in binding 2 + the body bytecode in
+// binding 3) and is compiled once per `GfxRuntime`. Host responsibility per dispatch:
+//   - Write the parameter blob (`AdStackMaxReducerParams` below) into the shared params storage buffer at the spec's
+//     descriptor-aligned offset, bound to slot 2 with a per-spec `VkDescriptorBufferInfo::offset`.
+//   - Encode the body subtree into the shared bytecode storage buffer at the spec's `body_bytecode_offset_words` slot
+//     using the existing `AdStackSizeExprDeviceNode` POD format from `quadrants/ir/adstack_size_expr_device.h`. Bind to
+//     slot 3.
+//   - Bind the kernel arg buffer to slot 0 (the same arg buffer the main kernel uses) so `kExternalTensorRead` body
+//     leaves can resolve their ndarray data pointers via the same byte-offset convention the main kernel uses.
+//   - Bind the per-kernel `AdStackMaxReducerOutput` u64 buffer to slot 1 with the matching `output_slot` cleared.
+//   - Dispatch `ceil(length / kAdStackMaxReducerWorkgroupSize)` workgroups of `kAdStackMaxReducerWorkgroupSize`
+//     threads.
+// After dispatch + sync the slot's value equals `max over i in [begin, begin + length): body[i]` interpreted in i64;
+// the host reads that value and substitutes it as a `Const` into the per-stack `SizeExpr` tree.
+//
+// Required device capabilities: `spirv_has_physical_storage_buffer` + `spirv_has_int64`. The first is needed because
+// every body leaf reads through the ndarray data pointer the kernel arg buffer carries (PSB load path, mirroring the
+// main kernel's ndarray access); the second is needed for non-atomic i64 arithmetic inside the body interpreter
+// (recognized bodies operate on i32 ndarray reads but the running scalar widens to i64 to match the host evaluator's
+// overflow semantics). The output atomic itself is u32 - the shader stores a `u32` max plus a `u32` overflow flag per
+// spec, atomic-max'd / atomic-or'd respectively, so spirv-cross's MSL backend (which rejects 64-bit atomics with the
+// error `MSL currently does not support 64-bit atomics`) translates cleanly through to Metal / Vulkan-via-MoltenVK. On
+// devices missing PSB or i64 support the function returns an empty vector and the runtime hard-errors at dispatch time
+// (`adstack_max_reducer_launch.cpp`'s `QD_ERROR_IF` gate). Silently falling through to the per-thread sizer's capped
+// path would corrupt reverse-mode gradients (the captured `MaxOverRange`'s 1<<24-truncated result undersizes the
+// heap), so failing loud is strictly safer. Quadrants's official Vulkan target is `VK_API_VERSION_1_3`, which promotes
+// both `VK_KHR_buffer_device_address` and `VK_KHR_shader_int64` (the underlying Vulkan caps) into core. The
+// empty-return branch is forward-looking for any non-Vulkan-1.3 device that might still surface here.
+std::vector<uint32_t> build_adstack_max_reducer_spirv(Arch arch, const DeviceCapabilityConfig *caps);
+
+// Compute-shader workgroup size (x dimension; y and z are 1). Power-of-two and a multiple of typical subgroup widths
+// on Metal / Vulkan, so dispatches pack into full subgroups with no partially filled tail inside a workgroup. Host
+// launcher uses this to compute `num_workgroups_x = (length + kAdStackMaxReducerWorkgroupSize - 1) /
+// kAdStackMaxReducerWorkgroupSize` per dispatch.
+constexpr uint32_t kAdStackMaxReducerWorkgroupSize = 128;
+
+// Maximum number of `AdStackSizeExprDeviceNode`s a single spec's body bytecode may contain. The shader's per-thread
+// post-order interpreter stores per-node i64 values in a Function-scope array sized by this constant; bumping it raises
+// the per-thread stack footprint by 8 bytes/node. Recognizer-grammar bodies observed in practice have 3-5 nodes, so the
+// 64-node cap leaves over an order of magnitude of headroom while keeping the per-thread stack at 512 bytes (well
+// below Metal's 4 KiB per-invocation private-memory budget). The host encoder hard-errors when a body subtree exceeds
+// this cap so the shader's array bounds are statically known.
+constexpr uint32_t kAdStackMaxReducerMaxBodyNodes = 64;
+
+// Layout of the parameter blob the host writes into binding 2 before each dispatch. POD; keep field order in sync with
+// the shader's compile-time word-offset constants in `adstack_max_reducer_shader.cpp`. Fields not relevant to a
+// particular spec (e.g. `begin_hi` when `begin` fits in 32 bits) are zero-initialised by the host launcher.
+struct AdStackMaxReducerParams {
+  // Slot index in the per-kernel `BufferType::AdStackMaxReducerOutput` u64 array that this dispatch's atomic-maxes
+  // accumulate into.
+  // Keyed by `(registry_id, stack_id, mor_node_idx)` packed into a single u32 by the host launcher's
+  // `MaxReducerCacheKey -> output_slot` allocator (mirrors how `AdStackBoundReducerParams::task_id_in_kernel` is
+  // assigned to `BufferType::AdStackRowCounter` slots).
+  uint32_t output_slot;
+  // Total number of cross-product iterations to dispatch over (product of every axis's `end - begin`, host-evaluated
+  // against the live ctx by `evaluate_adstack_size_expr_at_node` over the closed-form per-axis `begin` and `end`
+  // subtrees). Strided indices `>= length` are skipped in-shader, so the dispatch can be rounded up to the
+  // workgroup-size multiple.
+  uint32_t length;
+  // Number of captured chain axes (1..kAdStackMaxReducerMaxAxes). Axis 0 is the outermost `MaxOverRange`, axis
+  // `num_axes - 1` is the innermost. Single-axis specs set `num_axes == 1` and use only the first slot of every
+  // per-axis array below.
+  uint32_t num_axes;
+  // u32 word offset into the shared bytecode buffer (binding 3) where this spec's body bytecode begins. The bytecode
+  // is laid out as `AdStackSizeExprDeviceNode`s in post-order, with every bound-var id dense-remapped into
+  // `[0, num_axes)`, plus a trailing index-entry table at offset `body_bytecode_offset_words + body_node_count *
+  // kNodeWords`. The shader walks ascending node indices `0..body_node_count` reading nodes at
+  // `body_bytecode_offset_words + i * kNodeWords`.
+  uint32_t body_bytecode_offset_words;
+  // Number of nodes in this spec's body bytecode. Must satisfy `body_node_count <= kAdStackMaxReducerMaxBodyNodes`;
+  // the host encoder checks this and routes the spec back to the capped fallback path if exceeded.
+  uint32_t body_node_count;
+  // u32 word offset within the shared bytecode buffer where this spec's index-entry table begins (i.e.
+  // `body_bytecode_offset_words + body_node_count * kNodeWords`). Cached here rather than recomputed in the shader so
+  // the shader can index the table without a multiply.
+  uint32_t body_indices_offset_words;
+  // Per-axis iteration length (`end - begin`), ordered outermost-first. Indices beyond `num_axes` are zero-padded. The
+  // shader decomposes the linear thread index `gid` into per-axis indices via mod / div over these lengths.
+  uint32_t per_axis_length[kAdStackMaxReducerMaxAxes];
+  // Per-axis iteration base (`begin`) split into low + high 32-bit halves. The shader reassembles `(per_axis_begin_hi
+  // << 32) | per_axis_begin_lo` per axis as i64 and feeds `axis_begin + axis_idx` into the body's bound variable for
+  // that axis (`scope.values[per_axis_var_id[k]] = axis_begin + axis_idx`).
+  uint32_t per_axis_begin_lo[kAdStackMaxReducerMaxAxes];
+  uint32_t per_axis_begin_hi[kAdStackMaxReducerMaxAxes];
+  // Per-axis device-scope slot id, dense-remapped by the host encoder into `[0, num_axes)`. Used by the body
+  // interpreter's `kBoundVariable` case (`scope[var_id]`) and by `compute_external_read_elem_index` when a body
+  // ndarray-read references the chain axis as `-(slot + 1)`.
+  int32_t per_axis_var_id[kAdStackMaxReducerMaxAxes];
+
+  // Offset into the parameter blob (in u32 words) for each field; published to the shader and the host launcher as
+  // compile-time constants so each side reads/writes the same slots without a separate header serialisation step. The
+  // per-axis arrays land at fixed bases below so the shader's per-axis loops can compute element word offsets via
+  // `kWordOffsetPerAxis... + k` without an extra param read.
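+  // Resulting word map, with A = kAdStackMaxReducerMaxAxes (derived from the constants below):
+  //   [0] output_slot   [1] length   [2] num_axes   [3] body_bytecode_offset_words
+  //   [4] body_node_count   [5] body_indices_offset_words
+  //   [6, 6+A) per_axis_length   [6+A, 6+2A) per_axis_begin_lo
+  //   [6+2A, 6+3A) per_axis_begin_hi   [6+3A, 6+4A) per_axis_var_id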
+  static constexpr uint32_t kWordOffsetOutputSlot = 0;
+  static constexpr uint32_t kWordOffsetLength = 1;
+  static constexpr uint32_t kWordOffsetNumAxes = 2;
+  static constexpr uint32_t kWordOffsetBodyBytecodeOffsetWords = 3;
+  static constexpr uint32_t kWordOffsetBodyNodeCount = 4;
+  static constexpr uint32_t kWordOffsetBodyIndicesOffsetWords = 5;
+  static constexpr uint32_t kWordOffsetPerAxisLength = 6;
+  static constexpr uint32_t kWordOffsetPerAxisBeginLo =
+      kWordOffsetPerAxisLength + static_cast<uint32_t>(kAdStackMaxReducerMaxAxes);
+  static constexpr uint32_t kWordOffsetPerAxisBeginHi =
+      kWordOffsetPerAxisBeginLo + static_cast<uint32_t>(kAdStackMaxReducerMaxAxes);
+  static constexpr uint32_t kWordOffsetPerAxisVarId =
+      kWordOffsetPerAxisBeginHi + static_cast<uint32_t>(kAdStackMaxReducerMaxAxes);
+  static constexpr uint32_t kNumWords = kWordOffsetPerAxisVarId + static_cast<uint32_t>(kAdStackMaxReducerMaxAxes);
+};
+
+}  // namespace quadrants::lang::spirv
diff --git a/quadrants/codegen/spirv/adstack_sizer_shader.cpp b/quadrants/codegen/spirv/adstack_sizer_shader.cpp
index f6074397ed..82e49bcf41 100644
--- a/quadrants/codegen/spirv/adstack_sizer_shader.cpp
+++ b/quadrants/codegen/spirv/adstack_sizer_shader.cpp
@@ -215,6 +215,11 @@ struct ShaderState {
   Value bytecode_buf;
   Value metadata_buf;
   Value args_buf;
+  // Word index inside `metadata_buf` of the trailing overflow-flag slot. Computed once per dispatch in `main`
+  // (`2 + 2 * n_stacks`); the per-stack walker writes 1 here when it observes a `MaxOverRange` whose iteration count
+  // exceeds the `1 << 24` cap. The host launcher's post-readback path raises a `QuadrantsAssertionError` when the
+  // slot is non-zero, so a cap-hit surfaces as a clean error rather than an under-bounded heap.
+  Value overflow_flag_word_var;
 };

 // `values_arr` is a private, function-local i64 array of size `kMaxNodes` used to memoise the value of every
@@ -754,17 +759,40 @@ void emit_tree_eval_loop(IRBuilder &ir, const ShaderState &st) {
     {
       // scope[var_id] = begin
      store_scope_at(ir, st, var_id_i32, begin_i64);
-      // Push pending frame: pending[sp] = {...}; sp += 1.
-      // `pending_end_arr` is clamped to `min(end, begin + kMaxOverRangeIterations)` so the advance loop
-      // silently stops after the cap instead of running on-device until the driver's TDR fires. Matches
-      // the LLVM interpreter's `break` at the same `1 << 24` threshold and the host evaluator's hard
-      // QD_ERROR; on a single-thread `1x1x1` dispatch, unbounded iteration is the only one of the three
-      // paths that could hang the kernel rather than surface as a clean error at `qd.sync()`.
+      // Push pending frame: pending[sp] = {...}; sp += 1. `pending_end_arr` is clamped to `begin` when the iteration
+      // count exceeds the cap, so the advance loop walks zero iterations and the dispatch returns within bounded time
+      // even on the worst-case shape; the cap-hit also writes 1 into the trailing overflow-flag slot of
+      // `metadata_buf`, and the host post-readback raises a `QuadrantsAssertionError` when the slot is non-zero.
+      // Matches the host evaluator's `QD_ERROR_IF` in `adstack_size_expr_eval.cpp::evaluate_node` and the LLVM device
+      // sizer's `scope.overflow_observed` path. Recognized `MaxOverRange` shapes are dispatched in parallel by the
+      // max-reducer and substituted to a `Const` before the sizer walks the tree, so this path is reachable only for
+      // out-of-grammar shapes whose iteration count exceeds the cap.
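+      // Host-side equivalent of the clamp + flag sequence emitted below (illustrative only):
+      //   if (end - begin > (int64_t{1} << 24)) { effective_end = begin; metadata[overflow_word] = 1; }
+      //   else                                  { effective_end = end; }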
      constexpr int64_t kMaxOverRangeIterations = int64_t{1} << 24;
      Value cap_delta = ir.int_immediate_number(ir.i64_type(), kMaxOverRangeIterations);
      Value cap_end = ir.add(begin_i64, cap_delta);
      Value end_gt_cap = ir.gt(end_i64, cap_end);
-      Value effective_end = ir.select(end_gt_cap, cap_end, end_i64);
+      // Cap-hit collapses the walk: `effective_end = begin`, so no iterations run. The overflow flag below is the
+      // signal the host actually consumes; the cached `max_size` value falls through to its `max(_, 1)` floor, and
+      // the heap is never used because the host raises before the main kernel launches.
+      Value effective_end = ir.select(end_gt_cap, begin_i64, end_i64);
+
+      // Cap-hit overflow signal. Single-threaded dispatch, so a plain store rather than an atomic suffices. The slot
+      // is initialised to 0 by the host before dispatch; the value sticks at 1 for the remainder of the dispatch once
+      // any `MaxOverRange` walk in this task lands here, and the host post-readback path picks it up.
+      Label cap_then = ir.new_label();
+      Label cap_skip = ir.new_label();
+      Label cap_merge = ir.new_label();
+      ir.make_inst(spv::OpSelectionMerge, cap_merge, spv::SelectionControlMaskNone);
+      ir.make_inst(spv::OpBranchConditional, end_gt_cap, cap_then, cap_skip);
+      ir.start_label(cap_then);
+      {
+        Value overflow_word = ir.load_variable(st.overflow_flag_word_var, ir.u32_type());
+        store_buf_u32(ir, st.metadata_buf, overflow_word, ir.uint_immediate_number(ir.u32_type(), 1u));
+        ir.make_inst(spv::OpBranch, cap_merge);
+      }
+      ir.start_label(cap_skip);
+      ir.make_inst(spv::OpBranch, cap_merge);
+      ir.start_label(cap_merge);
      Value sp_val = ir.load_variable(st.sp_var, ir.i32_type());
      ir.store_variable(array_i32_access_ptr(ir, st.scratch_i32_buf, kI32BasePendingMorIdx, sp_val), current_now);
      Value body_start = ir.add(op_b_i32, ir.int_immediate_number(ir.i32_type(), 1));
@@ -891,6 +919,15 @@ std::vector<uint32_t> build_adstack_sizer_spirv(Arch arch, const DeviceCapabilityConfig *caps) {
   Value n_stacks_u32 = load_buf_u32(ir, bytecode_buf, ir.uint_immediate_number(ir.u32_type(), kHeaderOffNStacks));
   Value total_nodes_u32 = load_buf_u32(ir, bytecode_buf, ir.uint_immediate_number(ir.u32_type(), kHeaderOffTotalNodes));
+  // Cache the trailing overflow-flag slot's word index. The metadata layout is `[stride_float, stride_int, off0,
+  // max0, off1, max1, ..., overflow_flag]`, so the slot lives at index `2 + 2 * n_stacks`. The walker writes 1 here
+  // on a cap-hit (see the `kMaxOverRangeIterations` branch in the per-stack tree-eval loop); the host post-readback
+  // in `adstack_sizer_launch.cpp` checks the slot and raises if non-zero.
+  st.overflow_flag_word_var = ir.alloca_variable(ir.u32_type());
+  Value overflow_word_idx = ir.add(ir.uint_immediate_number(ir.u32_type(), 2u),
+                                   ir.mul(n_stacks_u32, ir.uint_immediate_number(ir.u32_type(), 2u)));
+  ir.store_variable(st.overflow_flag_word_var, overflow_word_idx);
+
   // Word-offsets inside the bytecode buffer for the nodes and indices arrays.
   Value header_words_u32 = ir.uint_immediate_number(ir.u32_type(), kHeaderWords);
   Value stack_header_words_u32 = ir.uint_immediate_number(ir.u32_type(), kStackHeaderWords);
diff --git a/quadrants/codegen/spirv/kernel_utils.h b/quadrants/codegen/spirv/kernel_utils.h
index 7c12196b0a..874e4cab50 100644
--- a/quadrants/codegen/spirv/kernel_utils.h
+++ b/quadrants/codegen/spirv/kernel_utils.h
@@ -207,6 +207,7 @@ struct TaskAttributes {
   // metadata. Aliased to the shared cross-backend struct in `quadrants/transforms/static_adstack_analysis.h`; the
   // SPIR-V codegen and the LLVM codegen consume the same captured representation through that header.
   using StaticBoundExpr = ::quadrants::lang::StaticAdStackBoundExpr;
+  using MaxReducerSpec = ::quadrants::lang::StaticAdStackMaxReducerSpec;

   struct AdStackSizingAttribs {
     // Compile-time-derived per-thread strides in elements of each heap's element type. The runtime recomputes these
@@ -226,7 +227,16 @@
     // offline cache: ids are assigned per `Program` lifetime; a deserialised task re-registers itself at
     // the next launch.
     uint32_t registry_id{0};
-    QD_IO_DEF(per_thread_stride_float_compile_time, per_thread_stride_int_compile_time, allocas, bound_expr);
+    // Per-task list of `MaxOverRange` nodes the runtime reduces in parallel via a dedicated max-reducer dispatch
+    // instead of letting the per-thread sizer enumerate. Empty when no captured `size_expr` contains a recognized
+    // shape; in that case every `MaxOverRange` falls through to the existing capped path (host: `QD_DEBUG_ADSTACK`
+    // tripwire; device: capped walk + overflow-flag raise).
+    std::vector<MaxReducerSpec> max_reducer_specs;
+    QD_IO_DEF(per_thread_stride_float_compile_time,
+              per_thread_stride_int_compile_time,
+              allocas,
+              bound_expr,
+              max_reducer_specs);
   };

   AdStackSizingAttribs ad_stack;
diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp b/quadrants/codegen/spirv/spirv_codegen.cpp
index d3be1a9688..4eb215f96b 100644
--- a/quadrants/codegen/spirv/spirv_codegen.cpp
+++ b/quadrants/codegen/spirv/spirv_codegen.cpp
@@ -10,6 +10,7 @@
 #include "quadrants/codegen/codegen_utils.h"
 #include "quadrants/program/program.h"
 #include "quadrants/program/kernel.h"
+#include "quadrants/program/adstack_size_expr_eval.h"
 #include "quadrants/ir/statements.h"
 #include "quadrants/ir/ir.h"
 #include "quadrants/util/line_appender.h"
@@ -202,6 +203,17 @@ TaskCodegen::Result TaskCodegen::run() {
   task_attribs_.ad_stack.per_thread_stride_float_compile_time = ad_stack_heap_per_thread_stride_float_;
   task_attribs_.ad_stack.per_thread_stride_int_compile_time = ad_stack_heap_per_thread_stride_int_;
+  // recognize `MaxOverRange` nodes the runtime can reduce in parallel via the dedicated max-reducer dispatch instead
+  // of letting the per-thread sizer enumerate. Indexing matches `task_attribs_.ad_stack.allocas` (each entry's
+  // `size_expr` is the per-stack tree captured above).
+  {
+    std::vector<SerializedSizeExpr> per_stack_size_exprs;
+    per_stack_size_exprs.reserve(task_attribs_.ad_stack.allocas.size());
+    for (const auto &a : task_attribs_.ad_stack.allocas) {
+      per_stack_size_exprs.push_back(a.size_expr);
+    }
+    task_attribs_.ad_stack.max_reducer_specs = recognize_adstack_max_reducer_specs(per_stack_size_exprs);
+  }

   // Snodes the task body mutates (any `GlobalStore` or `AtomicOp` whose dest resolves to a
   // `GlobalPtrStmt`). Persisted on `task_attribs_.snode_writes` so the SPIR-V launcher can bump
diff --git a/quadrants/ir/static_adstack_max_reducer_device.h b/quadrants/ir/static_adstack_max_reducer_device.h
new file mode 100644
index 0000000000..1ce1a23160
--- /dev/null
+++ b/quadrants/ir/static_adstack_max_reducer_device.h
@@ -0,0 +1,58 @@
+// Device-side parameter blob for the LLVM static-adstack max reducer. The host (`LlvmRuntimeExecutor`) fills this
+// struct on each launch with one captured `StaticAdStackMaxReducerSpec`'s dispatch parameters, memcpys it (plus the
+// body bytecode trailing blob) into a small device buffer, and calls `runtime_eval_adstack_max_reduce(runtime, ctx,
+// params_blob, body_bytecode)` as a single-thread serial function via the LLVM runtime JIT module - this mirrors how
+// `runtime_eval_static_bound_count` is invoked for the bound reducer.
+//
+// The body bytecode is a separate pointer because it varies in size per spec (nodes + indices arrays) while the
+// params blob has a fixed POD layout. The runtime function walks the cross-product of every axis range, evaluates the
+// body subtree against each axis's pre-populated bound variable, tracks the per-launch running max, and writes the
+// result into `runtime->adstack_max_reducer_outputs[output_slot]`. The caller substitutes the dispatched value as a
+// `Const` into the per-stack `SizeExpr` tree before any of the LLVM eval paths walks it.
+//
+// Shared with the SPIR-V variant (`AdStackMaxReducerParams` in `quadrants/codegen/spirv/adstack_max_reducer_shader.h`)
+// at the field-semantics level: `output_slot` and the per-axis lengths / begins carry the same meaning. The LLVM
+// variant passes one body bytecode blob per call (vs. the SPIR-V `body_bytecode_offset_words` / `body_node_count`,
+// which address into a shared bytecode buffer at descriptor-aligned offsets).
+#pragma once
+
+#include <cstdint>
+
+namespace quadrants::lang {
+
+// Maximum number of nested `MaxOverRange` axes the recognizer may absorb into a single max-reducer dispatch. The
+// recognizer's greedy chain capture (in `recognize_adstack_max_reducer_specs`) walks down nested `MaxOverRange`
+// bodies and accumulates one axis per layer; specs whose chain exceeds this cap fall back to the per-thread sizer.
+// Bumping the constant raises the per-spec params blob size by 16 bytes/axis on LLVM and 4 words/axis on SPIR-V.
+// Practical workloads (Genesis rigid-body kernels) capture 1-3 axes; keep the cap modest so the per-spec params blob
+// stays small on both backends.
+constexpr int32_t kAdStackMaxReducerMaxAxes = 8;
+
+struct LlvmAdStackMaxReducerDeviceParams {
+  // Slot index in `runtime->adstack_max_reducer_outputs` that this dispatch's running max is written into. Allocated
+  // by the host launcher per `StaticAdStackMaxReducerSpec` from the same `MaxReducerCacheKey -> output_slot` table
+  // the SPIR-V launcher uses, so the same numeric slot is consistent across backends.
+  uint32_t output_slot;
+  // Number of captured chain axes (1..kAdStackMaxReducerMaxAxes). Axis 0 is the outermost `MaxOverRange`, axis
+  // `num_axes - 1` is the innermost.
+  uint32_t num_axes;
+  // Number of `AdStackSizeExprDeviceNode`s in the body bytecode trailing blob. Bytecode layout:
+  // `[AdStackSizeExprDeviceNode x body_node_count][int32 x indices_count]`. `indices_count` is implicit in the
+  // node-side `indices_offset` / `indices_count` fields - the bytecode buffer simply contains the contiguous indices
+  // table after the nodes.
+  uint32_t body_node_count;
+  // Index of the body subtree's root within the body bytecode (post-order encoding places the root last, so
+  // `body_root_node_idx == body_node_count - 1`). Cached here so the runtime function does not need to subtract.
+  int32_t body_root_node_idx;
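+  // Worked example for the per-axis fields below (illustrative values): with `num_axes == 2` and
+  // `per_axis_length == {4, 5}`, a linear iteration index g in [0, 20) decomposes as (g / 5, g % 5), and each axis k
+  // contributes `scope.values[per_axis_var_id[k]] = per_axis_begin[k] + idx_k` before the body walk.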
+  // Per-axis iteration length (`end - begin`), ordered outermost-first. Axes beyond `num_axes` are zero-padded.
+  uint32_t per_axis_length[kAdStackMaxReducerMaxAxes];
+  // Per-axis iteration base (`begin`), ordered outermost-first. The runtime pre-populates
+  // `scope.values[per_axis_var_id[k]] = per_axis_begin[k] + i_k` for each axis before walking the body.
+  int64_t per_axis_begin[kAdStackMaxReducerMaxAxes];
+  // Per-axis device-scope slot id for the bound variable. Encoded by the host as a dense remap of the captured chain
+  // bound-var ids into `[0, num_axes)`; the body bytecode encodes references as `-(slot + 1)` (matching the existing
+  // device-side `device_eval_node` convention).
+  int32_t per_axis_var_id[kAdStackMaxReducerMaxAxes];
+};
+
+}  // namespace quadrants::lang
diff --git a/quadrants/program/adstack/cache.cpp b/quadrants/program/adstack/cache.cpp
new file mode 100644
index 0000000000..f8f3b9aaee
--- /dev/null
+++ b/quadrants/program/adstack/cache.cpp
@@ -0,0 +1,595 @@
+#include "quadrants/program/adstack/cache.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "quadrants/common/logging.h"
+#include "quadrants/ir/type.h"
+#include "quadrants/ir/type_factory.h"
+#include "quadrants/program/adstack/diagnose.h"
+#include "quadrants/program/adstack/eval.h"
+#include "quadrants/program/launch_context_builder.h"
+#include "quadrants/program/program.h"
+#include "quadrants/rhi/device.h"
+
+namespace quadrants::lang {
+
+namespace {
+
+// Read the input that `obs` describes against the live state and `ctx`. The caller compares the result to
+// `obs.observed_value` to decide whether the cached `SizeExprCacheEntry` is still valid. Each `obs.kind`
+// mirrors the corresponding leaf in `evaluate_field_load` / `evaluate_external_tensor_shape` /
+// `evaluate_external_tensor_read`.
+int64_t replay_one_observation(const AdStackCache::SizeExprReadObservation &obs,
+                               Program *prog,
+                               LaunchContextBuilder *ctx) {
+  using Obs = AdStackCache::SizeExprReadObservation;
+  switch (obs.kind) {
+    case Obs::FieldLoadObs: {
+      // Gen-counter fast skip: when no kernel has bumped this SNode's write generation since record time, the
+      // underlying field value cannot have changed and we can return the recorded `observed_value` without
+      // dispatching a reader kernel. The dispatch is the dominant per-launch cost on the hot path for steady-state
+      // reverse-mode loops with stable bounds.
+      if (prog != nullptr && prog->adstack_cache().snode_write_gen(obs.snode_id) == obs.observed_gen) {
+        return obs.observed_value;
+      }
+      // A max-reducer body FieldLoadObs (bound-var-indexed leaf) records `indices = {}`, since the body is evaluated
+      // at every cross-product point and there is no single canonical index to re-read. The gen counter is the only
+      // valid staleness signal in that mode; a gen mismatch unconditionally invalidates the cache.
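+      // e.g. a kernel whose task writes the SNode (or a host-side field write) bumps `snode_write_gen`; pure reads
+      // leave it unchanged, so steady-state launches take the fast skip above without dispatching a reader kernel.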
+      if (obs.indices.empty()) {
+        return obs.observed_value + 1;
+      }
+      int64_t v = read_field_with_launch_cache(obs.snode_id, obs.indices, prog);
+      if (v == std::numeric_limits<int64_t>::min()) {
+        return obs.observed_value + 1;  // force a mismatch if the SNode disappeared
+      }
+      return v;
+    }
+    case Obs::ExternalShapeObs: {
+      if (ctx == nullptr) {
+        return obs.observed_value + 1;
+      }
+      std::vector<int> arg_indices(obs.arg_id_path.begin(), obs.arg_id_path.end());
+      arg_indices.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY);
+      arg_indices.push_back(obs.arg_shape_axis);
+      return static_cast<int64_t>(ctx->get_struct_arg_host(arg_indices));
+    }
+    case Obs::ExternalReadObs: {
+      if (ctx == nullptr || obs.arg_id_path.empty()) {
+        return obs.observed_value + 1;
+      }
+      int arg_id = obs.arg_id_path[0];
+      ArgArrayPtrKey key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
+      auto it = ctx->array_ptrs.find(key);
+      if (it == ctx->array_ptrs.end()) {
+        return obs.observed_value + 1;
+      }
+      void *data_ptr = it->second;
+      // Gen-counter fast skip: when the data pointer is the same `DeviceAllocation *` we observed at record
+      // time AND its data generation has not been bumped since (no kernel write, no host-side `Ndarray.write`
+      // / `fill`), the underlying scalar cannot have changed and we can return the recorded value without
+      // dereferencing the device pointer (which on GPU would be a DtoH copy, on CPU a host load).
+      if (prog != nullptr && data_ptr == obs.observed_devalloc &&
+          prog->adstack_cache().ndarray_data_gen(data_ptr) == obs.observed_gen) {
+        return obs.observed_value;
+      }
+      int64_t linear = 0;
+      int64_t stride = 1;
+      for (std::size_t i = obs.indices.size(); i > 0; --i) {
+        linear += static_cast<int64_t>(obs.indices[i - 1]) * stride;
+        if (i - 1 > 0) {
+          std::vector<int> sh_idx(obs.arg_id_path.begin(), obs.arg_id_path.end());
+          sh_idx.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY);
+          sh_idx.push_back(static_cast<int>(i - 1));
+          stride *= static_cast<int64_t>(ctx->get_struct_arg_host(sh_idx));
+        }
+      }
+      switch (static_cast<PrimitiveTypeID>(obs.prim_dt)) {
+        case PrimitiveTypeID::i32:
+          return static_cast<int64_t>(static_cast<const int32_t *>(data_ptr)[linear]);
+        case PrimitiveTypeID::i64:
+          return static_cast<const int64_t *>(data_ptr)[linear];
+        case PrimitiveTypeID::u32:
+          return static_cast<int64_t>(static_cast<const uint32_t *>(data_ptr)[linear]);
+        case PrimitiveTypeID::u64:
+          return static_cast<int64_t>(static_cast<const uint64_t *>(data_ptr)[linear]);
+        case PrimitiveTypeID::i16:
+          return static_cast<int64_t>(static_cast<const int16_t *>(data_ptr)[linear]);
+        case PrimitiveTypeID::u16:
+          return static_cast<int64_t>(static_cast<const uint16_t *>(data_ptr)[linear]);
+        case PrimitiveTypeID::i8:
+          return static_cast<int64_t>(static_cast<const int8_t *>(data_ptr)[linear]);
+        case PrimitiveTypeID::u8:
+          return static_cast<int64_t>(static_cast<const uint8_t *>(data_ptr)[linear]);
+        default:
+          return obs.observed_value + 1;
+      }
+    }
+  }
+  return obs.observed_value + 1;
+}
+
+}  // namespace
+
+bool AdStackCache::try_size_expr_cache_hit(Program *prog,
+                                           const SerializedSizeExpr *expr_key,
+                                           LaunchContextBuilder *ctx,
+                                           int64_t &out_result) {
+  auto it = size_expr_cache_.find(expr_key);
+  if (it == size_expr_cache_.end()) {
+    return false;
+  }
+  const auto &entry = it->second;
+  for (const auto &obs : entry.reads) {
+    int64_t now = replay_one_observation(obs, prog, ctx);
+    if (now != obs.observed_value) {
+      size_expr_cache_.erase(it);
+      return false;
+    }
+  }
+  out_result = entry.result;
+  return true;
+}
+
+void AdStackCache::record_size_expr_eval(const SerializedSizeExpr *expr_key,
+                                         int64_t result,
+                                         std::vector<SizeExprReadObservation> reads) {
+  size_expr_cache_[expr_key] = SizeExprCacheEntry{result, std::move(reads)};
+}
+
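+// Worked example of the key packing defined just below (illustrative values): registry_id = 3, stack_id = 1,
+// mor_node_idx = 2 packs to 0x0002'0001'0000'0003 - the low u32 word holds the registry id, the next 16 bits the
+// stack id, the top 16 bits the MaxOverRange node index.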
+namespace {
+// Pack a `(registry_id, stack_id, mor_node_idx)` triple into a 64-bit map key. The recognizer caps both `stack_id`
+// and `mor_node_idx` at O(10s) per task (per-task adstack count and per-stack node count are both small), well within
+// 16 bits each, so the packed encoding never collides. `registry_id` uses the full low 32 bits since the program-side
+// registry can grow to thousands of entries across a long-running session.
+inline uint64_t pack_max_reducer_key(uint32_t registry_id, int32_t stack_id, int32_t mor_node_idx) {
+  return (static_cast<uint64_t>(registry_id) & 0xFFFFFFFFull) |
+         ((static_cast<uint64_t>(stack_id) & 0xFFFFull) << 32) |
+         ((static_cast<uint64_t>(mor_node_idx) & 0xFFFFull) << 48);
+}
+}  // namespace
+
+bool AdStackCache::try_max_reducer_cache_hit(uint32_t registry_id,
+                                             int32_t stack_id,
+                                             int32_t mor_node_idx,
+                                             LaunchContextBuilder *ctx,
+                                             int64_t &out_result) {
+  auto it = max_reducer_cache_.find(pack_max_reducer_key(registry_id, stack_id, mor_node_idx));
+  if (it == max_reducer_cache_.end()) {
+    return false;
+  }
+  const auto &entry = it->second;
+  for (const auto &obs : entry.reads) {
+    int64_t now = replay_one_observation(obs, prog_, ctx);
+    if (now != obs.observed_value) {
+      max_reducer_cache_.erase(it);
+      return false;
+    }
+  }
+  out_result = entry.result;
+  return true;
+}
+
+void populate_max_reducer_body_observations(std::vector<AdStackCache::SizeExprReadObservation> &reads,
+                                            LaunchContextBuilder *ctx,
+                                            AdStackCache *cache) {
+  for (auto &obs : reads) {
+    if (obs.kind == AdStackCache::SizeExprReadObservation::FieldLoadObs) {
+      // `FieldLoadObs` from a bound-var-indexed body leaf: snapshot the SNode write generation so a subsequent launch
+      // that has not mutated the SNode replays the cached max via `replay_one_observation`'s gen-fast-skip arm. Same
+      // sentinel rationale as `ExternalReadObs` below: the recognizer restricts the leaf dtype, so an `INT64_MIN`
+      // recorded value cannot equal a freshly-loaded one on cache miss.
+      obs.observed_value = std::numeric_limits<int64_t>::min();
+      if (cache != nullptr) {
+        obs.observed_gen = cache->snode_write_gen(obs.snode_id);
+      }
+      continue;
+    }
+    if (obs.kind != AdStackCache::SizeExprReadObservation::ExternalReadObs || obs.arg_id_path.empty()) {
+      continue;
+    }
+    if (ctx == nullptr) {
+      continue;
+    }
+    int arg_id = obs.arg_id_path[0];
+    ArgArrayPtrKey key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
+    auto it = ctx->array_ptrs.find(key);
+    if (it == ctx->array_ptrs.end()) {
+      continue;
+    }
+    obs.observed_devalloc = it->second;
+    // Pick an `observed_value` that no in-range ndarray scalar can equal (`INT64_MIN`). The replay code returns
+    // `obs.observed_value` verbatim when `ndarray_data_gen` still matches the recorded snapshot, so an `INT64_MIN`
+    // record is a self-equal cache hit. On a gen mismatch the replay re-dereferences `data[0]` instead, which (under
+    // any sub-i64 prim_dt the recognizer admits) widens to an i64 strictly greater than `INT64_MIN` and forces the
+    // cache to invalidate. The dispatched max itself lives in `MaxReducerCacheEntry::result`; this observation only
+    // gates whether the cache stays warm.
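+    // Decision table: gen match -> replay returns INT64_MIN (== recorded) -> hit, no dereference;
+    //                 gen mismatch -> replay loads data[0], widened value > INT64_MIN -> miss -> fresh dispatch.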
+    obs.observed_value = std::numeric_limits<int64_t>::min();
+    if (cache != nullptr) {
+      obs.observed_gen = cache->ndarray_data_gen(it->second);
+    }
+  }
+}
+
+const std::vector<AdStackCache::SizeExprReadObservation> *
+AdStackCache::lookup_max_reducer_reads(uint32_t registry_id, int32_t stack_id, int32_t mor_node_idx) const {
+  auto it = max_reducer_cache_.find(pack_max_reducer_key(registry_id, stack_id, mor_node_idx));
+  if (it == max_reducer_cache_.end()) {
+    return nullptr;
+  }
+  return &it->second.reads;
+}
+
+void AdStackCache::record_max_reducer_eval(uint32_t registry_id,
+                                           int32_t stack_id,
+                                           int32_t mor_node_idx,
+                                           int64_t result,
+                                           std::vector<SizeExprReadObservation> reads) {
+  max_reducer_cache_[pack_max_reducer_key(registry_id, stack_id, mor_node_idx)] =
+      MaxReducerCacheEntry{result, std::move(reads)};
+  ++max_reducer_dispatch_count_;
+}
+
+bool AdStackCache::try_spirv_bytecode_cache_hit(Program *prog,
+                                                const void *attribs_key,
+                                                LaunchContextBuilder *ctx,
+                                                std::vector<uint32_t> &out_bytecode) {
+  auto it = spirv_bytecode_cache_.find(attribs_key);
+  if (it == spirv_bytecode_cache_.end()) {
+    return false;
+  }
+  const auto &entry = it->second;
+  for (const auto &obs : entry.reads) {
+    int64_t now = replay_one_observation(obs, prog, ctx);
+    if (now != obs.observed_value) {
+      spirv_bytecode_cache_.erase(it);
+      return false;
+    }
+  }
+  out_bytecode = entry.bytecode;
+  return true;
+}
+
+void AdStackCache::record_spirv_bytecode_eval(const void *attribs_key,
+                                              std::vector<uint32_t> bytecode,
+                                              std::vector<SizeExprReadObservation> reads) {
+  spirv_bytecode_cache_[attribs_key] = SpirvBytecodeCacheEntry{std::move(bytecode), std::move(reads)};
+}
+
+void AdStackCache::record_per_task_ad_stack(const void *attribs_key,
+                                            std::vector<uint32_t> metadata,
+                                            uint32_t stride_float,
+                                            uint32_t stride_int,
+                                            std::vector<std::pair<int, uint64_t>> snode_gens,
+                                            std::vector<std::tuple<int, void *, uint64_t>> arg_gens) {
+  per_task_ad_stack_cache_[attribs_key] = PerTaskAdStackCacheEntry{std::move(metadata), stride_float, stride_int,
+                                                                   std::move(snode_gens), std::move(arg_gens)};
+}
+
+bool AdStackCache::try_per_task_ad_stack_cache_hit(const void *attribs_key,
+                                                   LaunchContextBuilder *ctx,
+                                                   PerTaskAdStackCacheEntry &out) {
+  auto it = per_task_ad_stack_cache_.find(attribs_key);
+  if (it == per_task_ad_stack_cache_.end()) {
+    return false;
+  }
+  const auto &entry = it->second;
+  for (const auto &snode_pair : entry.snode_gens) {
+    if (snode_write_gen(snode_pair.first) != snode_pair.second) {
+      per_task_ad_stack_cache_.erase(it);
+      return false;
+    }
+  }
+  for (const auto &arg_tuple : entry.arg_gens) {
+    int arg_id = std::get<0>(arg_tuple);
+    void *recorded_devalloc = std::get<1>(arg_tuple);
+    uint64_t recorded_gen = std::get<2>(arg_tuple);
+    void *current_devalloc = nullptr;
+    if (ctx != nullptr) {
+      ArgArrayPtrKey key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
+      auto ap_it = ctx->array_ptrs.find(key);
+      if (ap_it != ctx->array_ptrs.end()) {
+        current_devalloc = ap_it->second;
+      }
+    }
+    if (current_devalloc != recorded_devalloc) {
+      per_task_ad_stack_cache_.erase(it);
+      return false;
+    }
+    if (ndarray_data_gen(recorded_devalloc) != recorded_gen) {
+      per_task_ad_stack_cache_.erase(it);
+      return false;
+    }
+  }
+  out = entry;
+  return true;
+}
+
+void AdStackCache::record_llvm_per_task_ad_stack(const void *attribs_key,
+                                                 std::vector<uint64_t> offsets,
+                                                 std::vector<uint64_t> max_sizes,
+                                                 uint64_t stride_combined,
+                                                 uint64_t stride_float,
+                                                 uint64_t stride_int,
+                                                 std::vector<std::pair<int, uint64_t>> snode_gens,
+                                                 std::vector<std::tuple<int, void *, uint64_t>> arg_gens) {
+  llvm_per_task_ad_stack_cache_[attribs_key] =
+      LlvmPerTaskAdStackCacheEntry{std::move(offsets), std::move(max_sizes), stride_combined, stride_float,
+                                   stride_int, std::move(snode_gens),
+                                   std::move(arg_gens)};
+}
+
+bool AdStackCache::try_llvm_per_task_ad_stack_cache_hit(const void *attribs_key,
+                                                        LaunchContextBuilder *ctx,
+                                                        LlvmPerTaskAdStackCacheEntry &out) {
+  auto it = llvm_per_task_ad_stack_cache_.find(attribs_key);
+  if (it == llvm_per_task_ad_stack_cache_.end()) {
+    return false;
+  }
+  const auto &entry = it->second;
+  for (const auto &snode_pair : entry.snode_gens) {
+    if (snode_write_gen(snode_pair.first) != snode_pair.second) {
+      llvm_per_task_ad_stack_cache_.erase(it);
+      return false;
+    }
+  }
+  for (const auto &arg_tuple : entry.arg_gens) {
+    int arg_id = std::get<0>(arg_tuple);
+    void *recorded_devalloc = std::get<1>(arg_tuple);
+    uint64_t recorded_gen = std::get<2>(arg_tuple);
+    void *current_devalloc = nullptr;
+    if (ctx != nullptr) {
+      ArgArrayPtrKey key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
+      auto ap_it = ctx->array_ptrs.find(key);
+      if (ap_it != ctx->array_ptrs.end()) {
+        current_devalloc = ap_it->second;
+      }
+    }
+    if (current_devalloc != recorded_devalloc) {
+      llvm_per_task_ad_stack_cache_.erase(it);
+      return false;
+    }
+    if (ndarray_data_gen(recorded_devalloc) != recorded_gen) {
+      llvm_per_task_ad_stack_cache_.erase(it);
+      return false;
+    }
+  }
+  out = entry;
+  return true;
+}
+
+uint32_t AdStackCache::register_adstack_sizing_info(const void *identity_key,
+                                                    const std::string &kernel_name,
+                                                    int task_id_in_kernel,
+                                                    std::vector<uint64_t> allocated_max_sizes,
+                                                    std::vector<SerializedSizeExpr> size_exprs) {
+  std::lock_guard lk(adstack_sizing_info_registry_mutex_);
+  // Idempotent re-registration: the same `identity_key` yields the same id across re-compiles and updates the
+  // entry's metadata + size_exprs in place. The key is just an opaque dedup token - the registry never
+  // dereferences it; all data needed by the diagnose path is copied into the entry below.
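+  // (A re-compile of the same kernel passes the same `OffloadedTask`-owned pointer as `identity_key`, so the `find`
+  // below hits and the entry refreshes in place instead of minting a new id.)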
+  auto it = adstack_sizing_info_id_by_ptr_.find(identity_key);
+  if (it != adstack_sizing_info_id_by_ptr_.end()) {
+    auto &entry = adstack_sizing_info_registry_[it->second];
+    entry.kernel_name = kernel_name;
+    entry.task_id_in_kernel = task_id_in_kernel;
+    entry.allocated_max_sizes = std::move(allocated_max_sizes);
+    entry.size_exprs = std::move(size_exprs);
+    return it->second;
+  }
+  uint32_t id = static_cast<uint32_t>(adstack_sizing_info_registry_.size());
+  AdStackSizingInfoEntry entry;
+  entry.identity_key = identity_key;
+  entry.kernel_name = kernel_name;
+  entry.task_id_in_kernel = task_id_in_kernel;
+  entry.allocated_max_sizes = std::move(allocated_max_sizes);
+  entry.size_exprs = std::move(size_exprs);
+  adstack_sizing_info_registry_.push_back(std::move(entry));
+  adstack_sizing_info_id_by_ptr_.emplace(identity_key, id);
+  return id;
+}
+
+void AdStackCache::update_adstack_sizing_info_size_exprs(uint32_t id, std::vector<SerializedSizeExpr> size_exprs) {
+  std::lock_guard lk(adstack_sizing_info_registry_mutex_);
+  if (id == 0 || id >= adstack_sizing_info_registry_.size()) {
+    return;
+  }
+  adstack_sizing_info_registry_[id].size_exprs = std::move(size_exprs);
+}
+
+std::optional<AdStackCache::AdStackSizingInfoEntry> AdStackCache::lookup_adstack_sizing_info(uint32_t id) const {
+  std::lock_guard lk(adstack_sizing_info_registry_mutex_);
+  if (id == 0 || id >= adstack_sizing_info_registry_.size()) {
+    return std::nullopt;
+  }
+  return adstack_sizing_info_registry_[id];
+}
+
+std::string AdStackCache::diagnose_adstack_overflow_message(uint32_t task_id) const {
+  return diagnose_adstack_overflow(task_id).message;
+}
+
+AdStackCache::AdStackOverflowDiagnosis AdStackCache::diagnose_adstack_overflow(uint32_t task_id) const {
+  // Lazy LLVM capture: if the launcher stashed a pending ctx pointer for this launch (LLVM defers eager
+  // capture to avoid the per-launch snapshot cost), capture now before walking size_exprs. SPIR-V already
+  // captured eagerly at launch, so `pending_launch_ctx_` is null there.
+  if (pending_launch_ctx_ != nullptr) {
+    const_cast<AdStackCache *>(this)->capture_diagnose_snapshot(*pending_launch_ctx_);
+  }
+  std::string identity_block;
+  std::string disambiguation_block;
+  // Cause classifier: when the synchronous re-run produces required > allocated for ANY stack, the most likely
+  // cause is an untracked tensor mutation (DLPack bypass etc.). When all required <= allocated, the pre-pass
+  // undersized the bound (a Quadrants bug). When we cannot re-evaluate (e.g. no captured launch snapshot, or a
+  // leaf type the diagnose evaluator does not support), we fall through to the static dual-cause body.
+  enum class Cause { Unknown, DLPackBypass, QuadrantsBug };
+  Cause cause = Cause::Unknown;
+
+  if (task_id != 0) {
+    auto entry_opt = lookup_adstack_sizing_info(task_id);
+    if (entry_opt.has_value()) {
+      const auto &entry = *entry_opt;
+      identity_block = " Offending task: kernel `" + entry.kernel_name + "` offload task #" +
+                       std::to_string(entry.task_id_in_kernel) + "; per-stack allocated max_size = [";
+      for (size_t i = 0; i < entry.allocated_max_sizes.size(); ++i) {
+        if (i != 0) {
+          identity_block += ", ";
+        }
+        identity_block += std::to_string(entry.allocated_max_sizes[i]);
+      }
+      identity_block += "].\n";
+
+      // Synchronous sizer rerun: walk each stack's `SerializedSizeExpr` and evaluate against the live host /
+      // SNode state. Stacks whose tree contains an `ExternalTensorShape` or `ExternalTensorRead` leaf go
+      // through the snapshot-based `evaluate_adstack_size_expr_for_diagnose` (see its declaration for the
+      // `Device::map` design rationale). Pure host-resolvable trees go through the standard host evaluator.
+      // The disambiguation is best-effort: if every stack's tree resolves we get a precise classification;
+      // otherwise we report what we have and fall back to the static dual-cause hint.
+      if (!entry.size_exprs.empty()) {
+        std::vector<int64_t> required_sizes;
+        std::vector<bool> required_known;
+        size_t any_grew = 0;
+        size_t any_unknown = 0;
+        size_t total = std::min(entry.size_exprs.size(), entry.allocated_max_sizes.size());
+        for (size_t i = 0; i < total; ++i) {
+          const auto &expr = entry.size_exprs[i];
+          bool host_resolvable = true;
+          for (const auto &node : expr.nodes) {
+            auto k = static_cast<SizeExpr::Kind>(node.kind);
+            if (k == SizeExpr::Kind::ExternalTensorShape || k == SizeExpr::Kind::ExternalTensorRead) {
+              host_resolvable = false;
+              break;
+            }
+          }
+          int64_t v = -1;
+          if (host_resolvable && !expr.nodes.empty()) {
+            // Pure host-resolvable: SNode field loads, constants, arithmetic. `ctx == nullptr` is safe because
+            // every leaf we kept is host-resolvable; ETS / ETR are the only kinds that touch ctx, and we
+            // filtered them out.
+            v = evaluate_adstack_size_expr(expr, prog_, nullptr);
+          } else if (!expr.nodes.empty()) {
+            // Tree contains ETR / ETS leaves. The diagnose evaluator resolves them through the captured launch
+            // snapshot (`Device::map`-based ndarray reads). On failure (no snapshot, allocation cannot be
+            // mapped, unsupported dtype) the helper returns -1 and we fall through to the `?` placeholder.
+            int64_t diag = evaluate_adstack_size_expr_for_diagnose(expr, prog_);
+            if (diag >= 0) {
+              v = diag;
+            }
+          }
+          required_sizes.push_back(v);
+          required_known.push_back(!expr.nodes.empty() && v >= 0);
+          if (required_known.back() && static_cast<uint64_t>(v) > entry.allocated_max_sizes[i]) {
+            ++any_grew;
+          }
+          if (!required_known.back()) {
+            ++any_unknown;
+          }
+        }
+        if (any_grew > 0) {
+          cause = Cause::DLPackBypass;
+        } else if (any_unknown == 0 && total > 0) {
+          cause = Cause::QuadrantsBug;
+        }
+        // Only print the rerun line when at least one stack's bound resolves to a real value. With every leaf
+        // unresolved the line would read `required = [?, ?, ...]`, which adds zero signal beyond the dual-cause
+        // body that follows; the omission keeps the message focused on actionable content.
+        if (any_unknown < total) {
+          disambiguation_block = " Synchronous sizer rerun: required max_size = [";
+          for (size_t i = 0; i < required_sizes.size(); ++i) {
+            if (i != 0) {
+              disambiguation_block += ", ";
+            }
+            if (required_known[i]) {
+              disambiguation_block += std::to_string(required_sizes[i]);
+            } else {
+              disambiguation_block += "?";
+            }
+          }
+          disambiguation_block += "].";
+          if (any_unknown > 0) {
+            disambiguation_block +=
+                " (`?` = the sizer rerun could not resolve this stack's bound against the captured "
+                "launch state).";
+          }
+          disambiguation_block += "\n";
+        }
+      }
+    }
+  }
+
+  std::string body;
+  if (cause == Cause::DLPackBypass) {
+    body =
+        "Cause (sync sizer rerun): a tensor backing a data-dependent loop bound was mutated outside "
+        "Quadrants's tracking - typically a DLPack zero-copy mutation through a torch tensor sharing "
+        "storage with a Quadrants ndarray, or a raw pointer write through a non-torch DLPack consumer. "
+        "The cached adstack capacity was sized against the value before the mutation. 
Recovery: route " + "the mutation through Quadrants APIs (`Ndarray.write` / `fill` / kernel writes) so the cache " + "invalidates correctly, OR set a generous initial cap if a workload-change milestone genuinely " + "grew capacity. Restart the iteration / training loop from a clean state.\n"; + } else if (cause == Cause::QuadrantsBug) { + body = + "Cause (sync sizer rerun): the freshly-computed required size does not exceed the allocated " + "size for any stack - this is a Quadrants bug. The pre-pass resolved the alloca to a bound " + "tighter than the actual runtime push count: either the enclosing loop shape is outside the " + "current `SizeExpr` grammar, or the Bellman-Ford analyzer undercounted the forward-pass " + "accumulation. Please file with the kernel IR (`QD_DUMP_IR=1`).\n"; + } else { + body = + "Two possible causes (synchronous sizer rerun was not conclusive - some `SizeExpr` trees " + "depend on ndarray contents that are not host-resolvable without a per-launch context, or the " + "task-id slot was empty so the registry pointer could not be confirmed live):\n" + " 1. A tensor backing a data-dependent loop bound was mutated outside Quadrants's tracking " + "(typically a DLPack zero-copy mutation through a torch tensor sharing storage with a " + "Quadrants ndarray, or a raw pointer write through a non-torch DLPack consumer). The cached " + "adstack capacity was sized against the value before the mutation. Recovery: route the " + "mutation through Quadrants APIs (`Ndarray.write` / `fill` / kernel writes) so the cache " + "invalidates correctly, OR set a generous initial cap if a workload-change milestone " + "genuinely grew capacity. Restart the iteration / training loop from a clean state.\n" + " 2. (Quadrants bug) the pre-pass resolved the alloca to a bound tighter than the actual " + "runtime push count - the enclosing loop shape is outside the current `SizeExpr` grammar, or " + "the Bellman-Ford analyzer undercounted the forward-pass accumulation. Please file with the " + "kernel IR (`QD_DUMP_IR=1`).\n"; + } + AdStackOverflowDiagnosis result; + result.message = identity_block + disambiguation_block + body + + "Note: kernel state may be inconsistent post-overflow; do not retry the same " + "step without addressing the cause and restarting from a clean state."; + // Flag the cache as confirmed-invalid only when the sync rerun positively identified DLPack-bypass (`required + // > allocated` for at least one stack with every leaf resolved against the live snapshot). Unknown is a rare + // fallback now that the snapshot-based evaluator handles ndarray-bound leaves; treating it as + // confirmed-bypass would silently retry against a possibly-broken cache. Quadrants-bug is excluded for the + // same reason - the next launch would re-run the same wrong sizer and produce the same wrong bound. + result.confirmed_invalid_cache = (cause == Cause::DLPackBypass); + return result; +} + +void AdStackCache::capture_diagnose_snapshot(const LaunchContextBuilder &ctx) { + diagnose_snapshot_.data_ptrs.clear(); + diagnose_snapshot_.dev_alloc_types.clear(); + diagnose_snapshot_.shapes.clear(); + // Pull just the data-pointer slot for each arg; the grad-pointer slot is irrelevant to size_expr leaves. 
+  for (const auto &kv : ctx.array_ptrs) {
+    if (kv.first.ptr_type == TypeFactory::DATA_PTR_POS_IN_NDARRAY) {
+      diagnose_snapshot_.data_ptrs[kv.first.arg_id] = kv.second;
+    }
+  }
+  diagnose_snapshot_.dev_alloc_types = ctx.device_allocation_type;
+  // Mirror the per-arg shape vectors `LaunchContextBuilder` populated alongside the args-buffer writes. Going
+  // through this side map rather than `args_type->get_element_offset` avoids the spurious "Cannot treat as
+  // TensorType" diagnostics emitted when an axis lookup overruns the actual rank, and keeps the diagnose path
+  // independent of the `args_type` lifetime.
+  for (const auto &kv : ctx.ndarray_shapes) {
+    std::vector<int32_t> shape32(kv.second.begin(), kv.second.end());
+    diagnose_snapshot_.shapes[kv.first] = std::move(shape32);
+  }
+  diagnose_snapshot_.valid = true;
+}
+
+const AdStackCache::DiagnoseLaunchSnapshot *AdStackCache::get_diagnose_snapshot() const {
+  return diagnose_snapshot_.valid ? &diagnose_snapshot_ : nullptr;
+}
+
+}  // namespace quadrants::lang
diff --git a/quadrants/program/adstack/cache.h b/quadrants/program/adstack/cache.h
new file mode 100644
index 0000000000..33024dbfaf
--- /dev/null
+++ b/quadrants/program/adstack/cache.h
@@ -0,0 +1,355 @@
+#pragma once
+
+#include <cstdint>
+#include <mutex>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <vector>
+
+#include "quadrants/ir/adstack_size_expr.h"
+#include "quadrants/program/launch_context_builder.h"
+#include "quadrants/program/program.h"
+
+namespace quadrants::lang {
+
+class LaunchContextBuilder;
+class Program;
+
+// Adstack-specific state owned by `Program` and routed through `program->adstack_cache().method(...)`. Holds two
+// orthogonal pieces:
+//   1. The per-task adstack-sizer metadata caches (SPIR-V + LLVM-GPU), the encoded SPIR-V bytecode cache, the
+//      per-launch SizeExpr-eval result cache, and the per-snode / per-DeviceAllocation generation counters that
+//      drive precise invalidation.
+//   2. The adstack-overflow identity registry + diagnostic classifier that the codegen-emitted overflow path
+//      reads through (`Program::launch_kernel` populates `DiagnoseLaunchSnapshot`; the registry maps task ids
+//      to kernel + offload-task identities + per-stack capacities, and `diagnose_adstack_overflow` runs the
+//      synchronous sizer rerun against the captured snapshot to classify the failure mode).
+// Both pieces are adstack-internal and lived in `Program` historically; consolidating them here keeps the
+// `Program` surface focused on cross-feature program state.
+class AdStackCache {
+ public:
+  // Back-reference to `Program` is used by the diagnose path to reach `evaluate_adstack_size_expr` /
+  // `evaluate_adstack_size_expr_for_diagnose` (free functions that take `Program *`) and by the registry methods
+  // to access `get_compute_device()` for `Device::map`-based ndarray reads. Stored as a raw pointer because
+  // `AdStackCache` is owned by `Program` and shares its lifetime - the back-ref cannot dangle.
+  explicit AdStackCache(Program *prog) : prog_(prog) {
+  }
+
+  // One input read observed during an `evaluate_adstack_size_expr` walk. The cache entry records these so a
+  // subsequent lookup re-reads the same inputs and compares to `observed_value`; a single mismatch forces a full
+  // re-walk. `observed_gen` snapshots `snode_write_gen` (FieldLoadObs) or `ndarray_data_gen` (ExternalReadObs)
+  // at record
+  // time. The replay walk uses it as a fast-path short-circuit: if the gen counter has not advanced, the value
+  // cannot have changed and the dispatch (reader kernel for SNode reads, device-pointer deref for ndarray reads)
+  // is skipped. ExternalShapeObs reads the args buffer per launch (a cheap host-memory access), so it does not need
+  // a gen and leaves this field at 0.
+  struct SizeExprReadObservation {
+    enum Kind : uint8_t { FieldLoadObs, ExternalShapeObs, ExternalReadObs };
+    Kind kind;
+    int snode_id;
+    std::vector<int32_t> indices;
+    std::vector<int> arg_id_path;
+    int arg_shape_axis;
+    int prim_dt;
+    int64_t observed_value;
+    uint64_t observed_gen{0};
+    void *observed_devalloc{nullptr};
+  };
+  struct SizeExprCacheEntry {
+    int64_t result;
+    std::vector<SizeExprReadObservation> reads;
+  };
+  bool try_size_expr_cache_hit(Program *prog,
+                               const SerializedSizeExpr *expr_key,
+                               LaunchContextBuilder *ctx,
+                               int64_t &out_result);
+  void record_size_expr_eval(const SerializedSizeExpr *expr_key,
+                             int64_t result,
+                             std::vector<SizeExprReadObservation> reads);
+  void invalidate_size_expr() {
+    size_expr_cache_.clear();
+  }
+
+  // Cache for encoded SPIR-V adstack-sizer bytecode. Same dep-tracking contract as `try_size_expr_cache_hit`, but
+  // the cached payload is the encoded bytes rather than an integer.
+  struct SpirvBytecodeCacheEntry {
+    std::vector<uint32_t> bytecode;
+    std::vector<SizeExprReadObservation> reads;
+  };
+  bool try_spirv_bytecode_cache_hit(Program *prog,
+                                    const void *attribs_key,
+                                    LaunchContextBuilder *ctx,
+                                    std::vector<uint32_t> &out_bytecode);
+  void record_spirv_bytecode_eval(const void *attribs_key,
+                                  std::vector<uint32_t> bytecode,
+                                  std::vector<SizeExprReadObservation> reads);
+  void invalidate_spirv_bytecode() {
+    spirv_bytecode_cache_.clear();
+  }
+
+  // Per-task adstack metadata output cache for the SPIR-V on-device sizer.
+  struct PerTaskAdStackCacheEntry {
+    std::vector<uint32_t> metadata;
+    uint32_t stride_float{0};
+    uint32_t stride_int{0};
+    std::vector<std::pair<int, uint64_t>> snode_gens;
+    std::vector<std::tuple<int, void *, uint64_t>> arg_gens;
+  };
+  bool try_per_task_ad_stack_cache_hit(const void *attribs_key,
+                                       LaunchContextBuilder *ctx,
+                                       PerTaskAdStackCacheEntry &out);
+  void record_per_task_ad_stack(const void *attribs_key,
+                                std::vector<uint32_t> metadata,
+                                uint32_t stride_float,
+                                uint32_t stride_int,
+                                std::vector<std::pair<int, uint64_t>> snode_gens,
+                                std::vector<std::tuple<int, void *, uint64_t>> arg_gens);
+  void invalidate_per_task_ad_stack() {
+    per_task_ad_stack_cache_.clear();
+  }
+
+  // Per-task adstack metadata output cache for the LLVM-GPU on-device sizer (CUDA + AMDGPU).
+  struct LlvmPerTaskAdStackCacheEntry {
+    std::vector<uint64_t> offsets;
+    std::vector<uint64_t> max_sizes;
+    uint64_t stride_combined{0};
+    uint64_t stride_float{0};
+    uint64_t stride_int{0};
+    std::vector<std::pair<int, uint64_t>> snode_gens;
+    std::vector<std::tuple<int, void *, uint64_t>> arg_gens;
+  };
+  bool try_llvm_per_task_ad_stack_cache_hit(const void *attribs_key,
+                                            LaunchContextBuilder *ctx,
+                                            LlvmPerTaskAdStackCacheEntry &out);
+  void record_llvm_per_task_ad_stack(const void *attribs_key,
+                                     std::vector<uint64_t> offsets,
+                                     std::vector<uint64_t> max_sizes,
+                                     uint64_t stride_combined,
+                                     uint64_t stride_float,
+                                     uint64_t stride_int,
+                                     std::vector<std::pair<int, uint64_t>> snode_gens,
+                                     std::vector<std::tuple<int, void *, uint64_t>> arg_gens);
+  void invalidate_llvm_per_task_ad_stack() {
+    llvm_per_task_ad_stack_cache_.clear();
+  }
+
+  // Per-spec output cache for the max reducer. Keyed by `(registry_id, stack_id, mor_node_idx)` packed into a 64-bit
+  // key (low 32 bits = `registry_id`, mid 16 bits = `stack_id`, high 16 bits = `mor_node_idx`). The recognizer caps
+  // both `stack_id` and `mor_node_idx` well below 2^16 (per-task adstack count and per-stack node count are both
+  // O(10s)), so the packing is collision-free.
+  // Same observation-walk dependency tracking as `try_size_expr_cache_hit`: entries record the body's
+  // `ExternalTensorRead` reads plus the `begin` / `end` subtrees' leaves; the next launch re-walks the observations
+  // and short-circuits on a generation match.
+  struct MaxReducerCacheEntry {
+    int64_t result;
+    std::vector<SizeExprReadObservation> reads;
+  };
+  bool try_max_reducer_cache_hit(uint32_t registry_id,
+                                 int32_t stack_id,
+                                 int32_t mor_node_idx,
+                                 LaunchContextBuilder *ctx,
+                                 int64_t &out_result);
+  void record_max_reducer_eval(uint32_t registry_id,
+                               int32_t stack_id,
+                               int32_t mor_node_idx,
+                               int64_t result,
+                               std::vector<SizeExprReadObservation> reads);
+  void invalidate_max_reducer() {
+    max_reducer_cache_.clear();
+  }
+  // Read-only accessor for the observations recorded for a captured spec. Returns `nullptr` when the spec is not
+  // currently in the cache. Used by the bytecode encoder to thread the max-reducer body reads into the
+  // `spirv_bytecode_cache_` entry's observation list, so a mutation to the gating ndarray invalidates the bytecode
+  // cache (the encoder walks the post-substitution tree, where the body is already a `Const`, and would otherwise
+  // miss the underlying ndarray dependency).
+  const std::vector<SizeExprReadObservation> *lookup_max_reducer_reads(uint32_t registry_id,
+                                                                       int32_t stack_id,
+                                                                       int32_t mor_node_idx) const;
+  // Monotone counter, incremented once per `record_max_reducer_eval` call. Reset only by the surrounding test
+  // harness via `reset_max_reducer_dispatch_count`. Used by the regression tests to pin the cache short-circuit: a
+  // second launch with unchanged inputs must not advance the counter, and a host mutation must.
+  uint64_t max_reducer_dispatch_count() const {
+    return max_reducer_dispatch_count_;
+  }
+  void reset_max_reducer_dispatch_count() {
+    max_reducer_dispatch_count_ = 0;
+  }
+
+  // Bulk-invalidate just the per-task adstack metadata caches on the overflow raise path. The `size_expr_cache_` and
+  // `spirv_bytecode_cache_` are intentionally NOT cleared: they self-validate via per-read observation walks on the
+  // next lookup, so a DLPack-bypass mutation surfaces there as a normal observation mismatch and triggers a fresh
+  // evaluation without explicit eviction. The per-task metadata caches need a force-drop because their gen-counter
+  // snapshots match when the user's mutation bypassed our tracking. Invalidation is bulk (every task) rather than
+  // targeted (just the offender) because a single shared DLPack / torch view can back multiple tasks in the same
+  // kernel queue: targeted invalidation would let the next launch hit a stale entry on a different task that reads
+  // the same now-mutated tensor and overflow again. Also evicts the max-reducer cache, so a stale-cache overflow
+  // auto-recovers across every force-dropped layer.
+  void invalidate_all_per_task() {
+    invalidate_per_task_ad_stack();
+    invalidate_llvm_per_task_ad_stack();
+    invalidate_max_reducer();
+  }
+
+  uint64_t snode_write_gen(int snode_id) const {
+    auto it = snode_write_gen_.find(snode_id);
+    return it == snode_write_gen_.end() ? 0u : it->second;
+  }
+  void bump_snode_write_gen(int snode_id) {
+    ++snode_write_gen_[snode_id];
+  }
+  uint64_t ndarray_data_gen(void *devalloc_ptr) const {
+    auto it = ndarray_data_gen_.find(devalloc_ptr);
+    return it == ndarray_data_gen_.end() ? 0u : it->second;
+  }
+  void bump_ndarray_data_gen(void *devalloc_ptr) {
+    ++ndarray_data_gen_[devalloc_ptr];
+  }
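+  // Sketch of the intended call pattern for the generation counters (hypothetical call sites; the real hooks live in
+  // the launchers and the Ndarray API):
+  //   kernel launch:     for each SNode the task writes -> bump_snode_write_gen(snode_id);
+  //   Ndarray.write():   bump_ndarray_data_gen(devalloc_ptr);
+  //   replay fast-skip:  snode_write_gen(id) == obs.observed_gen  ->  reuse obs.observed_value.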
+  // Drop a per-DeviceAllocation entry. Called from `Ndarray::~Ndarray()` so the holder address can be reused by a
+  // future allocation without inheriting the destroyed ndarray's stale generation. Leftover snapshots in
+  // `per_task_ad_stack_cache_` / `llvm_per_task_ad_stack_cache_` referencing the dropped key fall back to gen=0
+  // on the next lookup (their stored snapshot will not match), which forces a fresh sizer dispatch and self-heals.
+  void erase_ndarray_data_gen(void *devalloc_ptr) {
+    ndarray_data_gen_.erase(devalloc_ptr);
+  }
+
+  // -----------------------------------------------------------------------------------------------------------
+  // Adstack-overflow identity registry + diagnostic classifier
+  // -----------------------------------------------------------------------------------------------------------
+  // Codegen registers each `OffloadedTask::ad_stack` once per kernel compilation and bakes the assigned id as
+  // an immediate into the lazy-claim overflow path; on overflow the codegen emits `cmpxchg(0, id)` against the
+  // pinned-host task-id slot. The host raise site reads the slot and routes through
+  // `diagnose_adstack_overflow_message(id)` to look up the kernel name, task index, and per-stack metadata for
+  // an enriched error message. Pointer ownership stays with `OffloadedTask`; entries are added but not removed
+  // - the registry size is bounded by the number of adstack-bearing tasks compiled in the program's lifetime,
+  // typically dozens. The diagnose path NEVER dereferences `identity_key`; all size-expression data is stored
+  // inline (`size_exprs`), so the entry is self-contained and immune to lifetime issues from the underlying
+  // `AdStackSizingInfo` (LLVM) / `AdStackSizingAttribs` (SPIR-V) struct moves.
+  struct AdStackSizingInfoEntry {
+    const void *identity_key{nullptr};
+    std::string kernel_name;
+    int task_id_in_kernel{0};
+    std::vector<uint64_t> allocated_max_sizes;
+    std::vector<SerializedSizeExpr> size_exprs;
+  };
+  uint32_t register_adstack_sizing_info(const void *identity_key,
+                                        const std::string &kernel_name,
+                                        int task_id_in_kernel,
+                                        std::vector<uint64_t> allocated_max_sizes,
+                                        std::vector<SerializedSizeExpr> size_exprs);
+  // Refresh just the `size_exprs` snapshot in an existing registry entry. Used by the LLVM launcher on the first
+  // launch of a task whose codegen-time registration could not capture size_exprs (the codegen-time
+  // `current_task->ad_stack` had not yet been finalized). No-op for `id == 0` and ids outside the registry range.
+  void update_adstack_sizing_info_size_exprs(uint32_t id, std::vector<SerializedSizeExpr> size_exprs);
+  // Returns a *copy* of the registry entry (not a pointer into the underlying vector) so the caller can safely
+  // hold the data across operations that might trigger another `register_adstack_sizing_info` and grow / reallocate
+  // the registry vector (e.g. `evaluate_adstack_size_expr` dispatching a reader kernel that compiles a fresh
+  // task). Returns `std::nullopt` for the sentinel id `0` and for out-of-range ids.
+  std::optional<AdStackSizingInfoEntry> lookup_adstack_sizing_info(uint32_t id) const;
+  // Format a diagnostic message for an overflow signal. `task_id` is the value read from the pinned-host task-id
+  // slot (0 if no thread overflowed; otherwise the registry id of the first overflowing task). The `message`
+  // field is embedded into the `QuadrantsAssertionError` raised at the poll site.
+  // The `confirmed_invalid_cache` field is true only when the synchronous sizer rerun classified the failure as a
+  // stale-cache / DLPack-bypass case (`required > allocated` for at least one stack with every leaf resolved against
+  // the captured launch snapshot); the caller (LLVM `check_adstack_overflow` / SPIR-V `GfxRuntime::synchronize`)
+  // uses it to decide whether to bulk-invalidate the per-task metadata caches so the next launch auto-recovers.
+  // We deliberately do NOT invalidate on Unknown / Quadrants-bug: invalidating would mask sizer bugs and could let
+  // a never-confirmed cause silently retry against a possibly-broken cache.
+  struct AdStackOverflowDiagnosis {
+    std::string message;
+    bool confirmed_invalid_cache{false};
+  };
+  AdStackOverflowDiagnosis diagnose_adstack_overflow(uint32_t task_id) const;
+  // Convenience wrapper that returns just the message string; production code uses `diagnose_adstack_overflow`
+  // to also act on the confirmed-cause signal.
+  std::string diagnose_adstack_overflow_message(uint32_t task_id) const;
+
+  // Snapshot of the most recent launch's context fields needed by `diagnose_adstack_overflow` to resolve
+  // ndarray-bound `SizeExpr` leaves (`ExternalTensorRead` / `ExternalTensorShape`) at error time, when the
+  // original `LaunchContextBuilder` is gone. Captured at the top of `Program::launch_kernel` BEFORE the
+  // launcher rewrites `array_ptrs` (the CPU launcher's `set_host_accessible_ndarray_ptrs` overwrites the
+  // `DeviceAllocation *` entry with a raw host pointer; capturing earlier keeps the original handle so the
+  // diagnose path can use the unified `Device::map` API instead of trusting backend-specific semantics).
+  //
+  // Design choice (vs. re-dispatching the on-device sizer at diagnose time): `Device::map` is virtual on
+  // every backend (CPU / CUDA / AMDGPU / Vulkan / Metal), so this snapshot-plus-map approach gets backend
+  // parity for free without re-entering the launcher's pipeline-setup machinery (compute pipelines /
+  // descriptor sets / command buffers / sync fences). The diagnose path stays out of the launch lifecycle.
+  struct DiagnoseLaunchSnapshot {
+    bool valid{false};
+    // arg_id -> ctx->array_ptrs[(arg_id, DATA_PTR_POS_IN_NDARRAY)]. For `kNone` numpy passthrough this is a
+    // raw host pointer. For `kNdarray` (qd.ndarray) this is a `DeviceAllocation *` handle the diagnose path
+    // dereferences via `Device::map`. Captured before the CPU launcher's `set_host_accessible_ndarray_ptrs`
+    // overwrite so the handle is uniform across backends.
+    std::unordered_map<int, void *> data_ptrs;
+    std::unordered_map<int, LaunchContextBuilder::DevAllocType> dev_alloc_types;
+    // Pre-extracted ndarray shapes (`ctx->get_struct_arg_host({arg_id, SHAPE_POS, axis})`) so the
+    // diagnose evaluator does not need a live `LaunchContextBuilder` to resolve `ExternalTensorShape` or
+    // multi-axis `ExternalTensorRead` strides.
+    std::unordered_map<int, std::vector<int32_t>> shapes;
+  };
+  // Capture the per-launch fields the diagnose evaluator needs (see `DiagnoseLaunchSnapshot`'s definition for
+  // the design rationale and field-by-field semantics). Called eagerly from `Program::launch_kernel` only on
+  // backends where the launch ctx is gone by the time overflow is detected (SPIR-V at `synchronize`); on LLVM
+  // backends the per-launch overflow poll runs while ctx is still in scope, so we stash the ctx pointer with
+  // `set_pending_launch_ctx` and let `diagnose_adstack_overflow` capture lazily on the (rare) overflow path.
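+  // In sequence (assumed flow per the comment above): SPIR-V -> capture_diagnose_snapshot(ctx) at every launch;
+  // LLVM -> set_pending_launch_ctx(&ctx), launch, overflow poll, lazy capture inside diagnose_adstack_overflow,
+  // set_pending_launch_ctx(nullptr).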
+  void capture_diagnose_snapshot(const LaunchContextBuilder &ctx);
+  // Lazy-capture handoff: `Program::launch_kernel` on LLVM backends sets this to the in-scope ctx before
+  // forwarding into the launcher and clears it after the per-launch overflow poll returns. If the poll fires,
+  // `diagnose_adstack_overflow` reads the pointer and captures the snapshot just in time. Stored as a raw
+  // pointer because it is transient per-launch and never outlives the call frame that set it.
+  void set_pending_launch_ctx(const LaunchContextBuilder *ctx) {
+    pending_launch_ctx_ = ctx;
+  }
+  // Read-only accessor for the latest snapshot, used by `diagnose_adstack_overflow` to resolve ndarray-bound
+  // size_expr leaves. Returns `nullptr` when no launch has happened yet (e.g. a freshly constructed `Program`
+  // hits `synchronize` during teardown without a prior kernel launch).
+  const DiagnoseLaunchSnapshot *get_diagnose_snapshot() const;
+
+ private:
+  Program *prog_{nullptr};
+  std::unordered_map<const SerializedSizeExpr *, SizeExprCacheEntry> size_expr_cache_;
+  std::unordered_map<const void *, SpirvBytecodeCacheEntry> spirv_bytecode_cache_;
+  std::unordered_map<const void *, PerTaskAdStackCacheEntry> per_task_ad_stack_cache_;
+  std::unordered_map<const void *, LlvmPerTaskAdStackCacheEntry> llvm_per_task_ad_stack_cache_;
+  // Max-reducer per-spec output cache. Key encoding: low 32 bits = `registry_id`, mid 16 bits = `stack_id`, high 16
+  // bits = `mor_node_idx`. See `try_max_reducer_cache_hit` for the contract and `pack_max_reducer_key` in
+  // `cache.cpp` for the packing helper.
+  std::unordered_map<uint64_t, MaxReducerCacheEntry> max_reducer_cache_;
+  // See `max_reducer_dispatch_count` for the contract. Bumped at every `record_max_reducer_eval` call (i.e. once per
+  // cache miss that fired a real dispatch); cache hits do not bump it.
+  uint64_t max_reducer_dispatch_count_{0};
+  std::unordered_map<int, uint64_t> snode_write_gen_;
+  std::unordered_map<void *, uint64_t> ndarray_data_gen_;
+
+  // Adstack-overflow identity registry storage. Index 0 is reserved as the "no overflow" sentinel so the
+  // codegen-emitted `cmpxchg(0, id)` cleanly distinguishes "task id recorded" from "slot still clean". The
+  // reverse lookup map (keyed by `identity_key`) keeps `register_adstack_sizing_info` idempotent across
+  // re-launches of the same kernel.
+  std::vector<AdStackSizingInfoEntry> adstack_sizing_info_registry_{AdStackSizingInfoEntry{}};
+  std::unordered_map<const void *, uint32_t> adstack_sizing_info_id_by_ptr_;
+  mutable std::mutex adstack_sizing_info_registry_mutex_;
+  // Latest captured launch-context snapshot for the diagnose path's ndarray-bound leaf resolution. See
+  // `DiagnoseLaunchSnapshot`'s comment above for why we capture in `Program::launch_kernel` before the launcher
+  // forwards.
+  // Single-threaded by construction: `capture_diagnose_snapshot` runs from `Program::launch_kernel` (the Python
+  // launcher thread) and `get_diagnose_snapshot` runs from `diagnose_adstack_overflow` on the same thread; no
+  // mutex needed. The codegen-time identity registry above keeps its mutex because it is hit from compilation
+  // worker threads.
+  DiagnoseLaunchSnapshot diagnose_snapshot_;
+  // Transient ctx handoff for the lazy LLVM capture path. See `set_pending_launch_ctx`.
+  const LaunchContextBuilder *pending_launch_ctx_{nullptr};
+};
+
+// Snapshot the live ndarray data pointer + generation counter into each `ExternalReadObs` record. The encoder emits
+// the observation skeleton (kind / arg_id_path / prim_dt) but cannot fill in the runtime-resolved `data_ptr` /
+// `observed_gen` / `observed_value` because it has no `LaunchContextBuilder`.
This helper closes that gap right before +// the max-reducer dispatch site calls `AdStackCache::record_max_reducer_eval`, so the next launch's +// `try_max_reducer_cache_hit` replay can fast-skip on a matching `ndarray_data_gen`. `observed_value` is recorded as +// `INT64_MIN` so the replay's gen-mismatch dereference path returns a value strictly greater than the recorded sentinel +// and forces the cache to invalidate; the cached max itself is stored in `MaxReducerCacheEntry::result`, not in any +// per-leaf observation. +void populate_max_reducer_body_observations(std::vector &reads, + LaunchContextBuilder *ctx, + AdStackCache *cache); + +} // namespace quadrants::lang diff --git a/quadrants/program/adstack/device_bytecode.cpp b/quadrants/program/adstack/device_bytecode.cpp new file mode 100644 index 0000000000..ee682a58f2 --- /dev/null +++ b/quadrants/program/adstack/device_bytecode.cpp @@ -0,0 +1,782 @@ +#include "quadrants/program/adstack/device_bytecode.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "quadrants/codegen/spirv/adstack_sizer_shader.h" +#include "quadrants/common/logging.h" +#include "quadrants/ir/adstack_size_expr_device.h" +#include "quadrants/ir/snode.h" +#include "quadrants/ir/type.h" +#include "quadrants/ir/type_factory.h" +#include "quadrants/ir/type_utils.h" +#include "quadrants/program/adstack/eval.h" +#include "quadrants/program/adstack/max_reducer.h" +#include "quadrants/program/launch_context_builder.h" +#include "quadrants/program/program.h" +#include "quadrants/rhi/device.h" + +namespace quadrants::lang { + +namespace { + +using ReadSink = std::vector; + +// -------------------------------------------------------------------------------------------------------------- +// Device-bytecode encoder helpers +// -------------------------------------------------------------------------------------------------------------- + +// `contains_device_leaf[i]` is true when subtree rooted at node `i` has at least one leaf that MUST stay on the +// device during encoding (the host fold cannot substitute it with a `Const`). On the LLVM path this is any +// `ExternalTensorRead` leaf - `FieldLoad` can be host-folded via `SNodeRwAccessorsBank::read_int`, which is safe +// on CPU / CUDA / AMDGPU. On the SPIR-V path the caller flips `fieldload_stays_on_device` to true because on +// MoltenVK a nested `read_int` submit crashes inside the descriptor-set bind path; keeping `FieldLoad` on the +// device side (via PSB loads in the sizer shader) avoids that entirely. Computed bottom-up; `SerializedSizeExpr` +// is already in post-order so every operand / body index is < i. +std::vector compute_contains_device_leaf(const SerializedSizeExpr &expr, bool fieldload_stays_on_device) { + std::vector result(expr.nodes.size(), false); + for (std::size_t i = 0; i < expr.nodes.size(); ++i) { + const auto &node = expr.nodes[i]; + auto kind = static_cast(node.kind); + bool hit = (kind == SizeExpr::Kind::ExternalTensorRead) || + (fieldload_stays_on_device && kind == SizeExpr::Kind::FieldLoad); + if (!hit && node.operand_a >= 0) + hit = result[node.operand_a]; + if (!hit && node.operand_b >= 0) + hit = result[node.operand_b]; + if (!hit && node.body_node_idx >= 0) + hit = result[node.body_node_idx]; + result[i] = hit; + } + return result; +} + +// `free_vars[i]` is the set of `BoundVariable::var_id`s referenced inside subtree(i) but NOT bound by any +// `MaxOverRange` inside that same subtree. 
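+// (Worked example, illustrative: for `MaxOverRange(i)` over `Add(ExternalTensorRead(n[i]), BoundVariable(j))`, the
+// body subtree's free set is {i, j}; the enclosing MaxOverRange binds `i`, so the whole tree's free set is {j} -
+// still open, hence not evaluable without an outer-iteration context.)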
An empty set means the subtree is closed and can be evaluated on the +// host without an outer-iteration context. `FieldLoad` / `ExternalTensorRead` index slots use the same +// `-(var_id + 1)` encoding as `BoundVariable` and are accounted for here. +std::vector> compute_free_vars(const SerializedSizeExpr &expr) { + std::vector> result(expr.nodes.size()); + for (std::size_t i = 0; i < expr.nodes.size(); ++i) { + const auto &node = expr.nodes[i]; + auto &fv = result[i]; + auto collect_idx_vars = [&](const std::vector &indices) { + for (int32_t raw : indices) { + if (raw < 0) + fv.insert(-(raw + 1)); + } + }; + switch (static_cast(node.kind)) { + case SizeExpr::Kind::Const: + case SizeExpr::Kind::ExternalTensorShape: + break; + case SizeExpr::Kind::BoundVariable: + fv.insert(node.var_id); + break; + case SizeExpr::Kind::FieldLoad: + case SizeExpr::Kind::ExternalTensorRead: + collect_idx_vars(node.indices); + break; + case SizeExpr::Kind::Add: + case SizeExpr::Kind::Sub: + case SizeExpr::Kind::Mul: + case SizeExpr::Kind::Max: + fv = result[node.operand_a]; + for (auto v : result[node.operand_b]) + fv.insert(v); + break; + case SizeExpr::Kind::MaxOverRange: { + fv = result[node.operand_a]; + for (auto v : result[node.operand_b]) + fv.insert(v); + // MaxOverRange binds `var_id` for its body only: body's free vars minus this binding add into the + // outer set. + for (auto v : result[node.body_node_idx]) { + if (v != node.var_id) + fv.insert(v); + } + break; + } + } + } + return result; +} + +// Walks `expr` and builds a dense `original_var_id -> [0, N)` map across every `var_id` the tree references +// (`MaxOverRange` binds, `BoundVariable` leaves, and bound-var entries inside each ETR / FieldLoad index list). +// The walker preserves encounter order so nested `MaxOverRange` binds keep monotonically increasing dense ids, +// which also matches the natural `values[]` indexing the device interpreter does at each bind. Hard-errors if +// the tree references more distinct bound vars than the device interpreter's per-stack scope capacity. +std::unordered_map build_dense_var_id_remap(const SerializedSizeExpr &expr) { + std::unordered_map remap; + auto add = [&](int32_t v) { + if (v < 0) + return; + if (remap.find(v) == remap.end()) { + int32_t dense = static_cast(remap.size()); + remap.emplace(v, dense); + } + }; + for (const auto &node : expr.nodes) { + const auto kind = static_cast(node.kind); + if (kind == SizeExpr::Kind::MaxOverRange || kind == SizeExpr::Kind::BoundVariable) + add(node.var_id); + for (int32_t raw : node.indices) { + if (raw < 0) + add(-(raw + 1)); + } + } + QD_ERROR_IF(static_cast(remap.size()) > kAdStackSizeExprDeviceMaxBoundVars, + "Adstack SizeExpr tree references {} distinct bound variable ids, which exceeds the device " + "interpreter's per-stack scope capacity ({}). This almost always indicates a deeply nested " + "reverse-mode loop shape that the pre-pass should have folded earlier; shrink the enclosing " + "loops or file a bug so the grammar / walker can be tightened.", + remap.size(), kAdStackSizeExprDeviceMaxBoundVars); + return remap; +} + +// Computes the maximum `MaxOverRange` nesting depth reachable from any root in `expr`, i.e. the deepest +// chain of `MaxOverRange` nodes whose `body_node_idx` recursively references another `MaxOverRange`. 
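+// (Worked example, illustrative: a doubly ragged bound `max_{i in [0,N)} max_{j in [0,n[i])} m[i,j]` serializes as
+// a MaxOverRange(i) whose body is a MaxOverRange(j) whose body is an ExternalTensorRead - nesting depth 2; a flat
+// `max_{i} n[i]` is depth 1; a tree with no MaxOverRange at all is depth 0.)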
+// The sizer shader's per-invocation pending-frame stack is sized to `kAdStackSizerMaxPendingFrames`; the encoder
+// hard-errors when a tree's nesting exceeds this so the shader's fixed-size access-chain stays in bounds
+// without a runtime guard. Each node's depth is memoised to keep the walk linear in `expr.nodes.size()`.
+int32_t compute_max_mor_nesting(const SerializedSizeExpr &expr) {
+  std::vector<int32_t> depth(expr.nodes.size(), -1);
+  std::function<int32_t(int32_t)> visit = [&](int32_t i) -> int32_t {
+    if (i < 0 || static_cast<std::size_t>(i) >= expr.nodes.size())
+      return 0;
+    if (depth[i] >= 0)
+      return depth[i];
+    const auto &n = expr.nodes[i];
+    int32_t child_max = 0;
+    for (int32_t c : {n.operand_a, n.operand_b, n.body_node_idx}) {
+      if (c >= 0)
+        child_max = std::max(child_max, visit(c));
+    }
+    int32_t self = static_cast<SizeExpr::Kind>(n.kind) == SizeExpr::Kind::MaxOverRange ? 1 : 0;
+    depth[i] = self + child_max;
+    return depth[i];
+  };
+  int32_t max_depth = 0;
+  for (std::size_t i = 0; i < expr.nodes.size(); ++i) {
+    max_depth = std::max(max_depth, visit(static_cast<int32_t>(i)));
+  }
+  return max_depth;
+}
+
+// Returns the dense id for `original_var_id`, or fires a hard error if the remap lost track of it (which would
+// indicate a walker divergence between `build_dense_var_id_remap` and `encode_subtree`).
+int32_t remap_var_id(const std::unordered_map<int32_t, int32_t> &remap, int32_t original) {
+  auto it = remap.find(original);
+  QD_ASSERT_INFO(it != remap.end(),
+                 "Adstack SizeExpr encoder saw var_id={} not present in the dense remap; this "
+                 "is a walker bug between `build_dense_var_id_remap` and `encode_subtree`.",
+                 original);
+  return it->second;
+}
+
+// Initialises a fresh device node with every unused slot sentinelled so the interpreter can tell them apart from
+// legitimate zero-valued slots (e.g. `operand_a == 0` is a valid node index; only `-1` signals "unused").
+AdStackSizeExprDeviceNode make_empty_device_node(int32_t kind) {
+  AdStackSizeExprDeviceNode dn{};
+  dn.kind = kind;
+  dn.operand_a = -1;
+  dn.operand_b = -1;
+  dn.body_node_idx = -1;
+  dn.var_id = -1;
+  dn.prim_dt = -1;
+  dn.arg_buffer_offset = -1;
+  dn.indices_offset = 0;
+  dn.indices_count = 0;
+  dn._pad0 = 0;
+  dn.const_value = 0;
+  return dn;
+}
+
+// Recursive top-down encoder. Each call returns the index of the emitted root in `out_nodes`. Subtrees whose
+// leaves are all host-resolvable (no `ExternalTensorRead`, and - on the LLVM path - no `FieldLoad` either) and
+// whose bound variables are all locally bound within the subtree get folded to a single `kConst` device node
+// by running `evaluate_node` over them. On the SPIR-V path, `FieldLoad` also survives as a `kFieldLoad` device
+// node alongside `kExternalTensorRead`, so the shader can resolve the snode read in place via PSB.
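+// Illustrative walk-through (names assumed, not from this file): for a bound `max(arr.shape[0], n[i])` where `n`
+// is an ndarray and `i` an enclosing MaxOverRange binding, the `arr.shape[0]` subtree is closed and free of device
+// leaves, so it folds to a single kConst; `n[i]` is an ExternalTensorRead with a free bound var, so it survives as
+// a device node; the enclosing Max therefore survives too, with one kConst and one kExternalTensorRead operand.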
+int32_t encode_subtree(const SerializedSizeExpr &src, + int32_t src_idx, + const std::vector &contains_device_leaf, + const std::vector> &free_vars, + const std::unordered_map &var_id_remap, + Program *prog, + LaunchContextBuilder *ctx, + const FieldLoadDeviceEmitter &fl_emitter, + std::vector &out_nodes, + std::vector &out_indices, + ReadSink *reads) { + QD_ASSERT_INFO(src_idx >= 0 && static_cast(src_idx) < src.nodes.size(), + "encode_subtree: src_idx {} out of bounds (size={})", src_idx, src.nodes.size()); + const bool subtree_needs_device = contains_device_leaf[src_idx]; + const bool subtree_closed = free_vars[src_idx].empty(); + + if (!subtree_needs_device && subtree_closed) { + // Whole subtree resolves without any device-resident read and without an outer-iteration context, so fold it + // to a single `Const` by running the host evaluator over it. This is the only path that can substitute + // `FieldLoad` / `ExternalTensorShape` leaves - the device interpreter does not know how to walk SNodes or + // index into `args_type`. + std::unordered_map empty_bound; + int64_t val = evaluate_node(src, src_idx, empty_bound, prog, ctx, reads); + AdStackSizeExprDeviceNode dn = make_empty_device_node(static_cast(AdStackSizeExprDeviceKind::kConst)); + dn.const_value = val; + out_nodes.push_back(dn); + return static_cast(out_nodes.size() - 1); + } + + const auto &node = src.nodes[src_idx]; + const auto kind = static_cast(node.kind); + switch (kind) { + case SizeExpr::Kind::Const: { + AdStackSizeExprDeviceNode dn = make_empty_device_node(static_cast(AdStackSizeExprDeviceKind::kConst)); + dn.const_value = node.const_value; + out_nodes.push_back(dn); + return static_cast(out_nodes.size() - 1); + } + case SizeExpr::Kind::BoundVariable: { + AdStackSizeExprDeviceNode dn = + make_empty_device_node(static_cast(AdStackSizeExprDeviceKind::kBoundVariable)); + dn.var_id = remap_var_id(var_id_remap, node.var_id); + out_nodes.push_back(dn); + return static_cast(out_nodes.size() - 1); + } + case SizeExpr::Kind::Add: + case SizeExpr::Kind::Sub: + case SizeExpr::Kind::Mul: + case SizeExpr::Kind::Max: { + int32_t a = encode_subtree(src, node.operand_a, contains_device_leaf, free_vars, var_id_remap, prog, ctx, + fl_emitter, out_nodes, out_indices, reads); + int32_t b = encode_subtree(src, node.operand_b, contains_device_leaf, free_vars, var_id_remap, prog, ctx, + fl_emitter, out_nodes, out_indices, reads); + AdStackSizeExprDeviceKind dk = AdStackSizeExprDeviceKind::kAdd; + if (kind == SizeExpr::Kind::Sub) + dk = AdStackSizeExprDeviceKind::kSub; + else if (kind == SizeExpr::Kind::Mul) + dk = AdStackSizeExprDeviceKind::kMul; + else if (kind == SizeExpr::Kind::Max) + dk = AdStackSizeExprDeviceKind::kMax; + AdStackSizeExprDeviceNode dn = make_empty_device_node(static_cast(dk)); + dn.operand_a = a; + dn.operand_b = b; + out_nodes.push_back(dn); + return static_cast(out_nodes.size() - 1); + } + case SizeExpr::Kind::MaxOverRange: { + int32_t a = encode_subtree(src, node.operand_a, contains_device_leaf, free_vars, var_id_remap, prog, ctx, + fl_emitter, out_nodes, out_indices, reads); + int32_t b = encode_subtree(src, node.operand_b, contains_device_leaf, free_vars, var_id_remap, prog, ctx, + fl_emitter, out_nodes, out_indices, reads); + // Iteration cap pre-check at encode time. Mirrors the host evaluator's `QD_ERROR_IF` in `evaluate_node`'s + // `MaxOverRange` arm. 
The recognizer's parallel-reducer pass substitutes recognized shapes by `Const` before the + // encoder walks the tree, so any `MaxOverRange` reaching this branch is out-of-grammar; the device sizer's + // `kMaxOverRange` arm short-circuits past the cap to keep the on-device walk inside the driver's TDR window, but + // on the LLVM-GPU paths there is no host-visible signal afterwards (the release-mode codegen for + // `AdStackPushStmt` drops the `n + 1 > max_num_elements` guard, so the corner-thread overflow is silent). Raising + // here surfaces the cap-hit as a `RuntimeError` from the launcher before any device dispatch. Both `begin` and + // `end` come from operand subtrees the encoder already lowered above; the encoder folds any closed, + // host-resolvable subtree to `kConst` at the top of this function, so the post-encode kind check captures every + // shape host-resolvable at encode time without re-evaluating the operand subtrees. If either operand still + // references a free outer-scope bound variable (nested-`MaxOverRange` ragged case), the operand is not `kConst` + // and we fall through to the device-side short-circuit. + if (out_nodes[a].kind == static_cast(AdStackSizeExprDeviceKind::kConst) && + out_nodes[b].kind == static_cast(AdStackSizeExprDeviceKind::kConst)) { + const int64_t begin_v = out_nodes[a].const_value; + const int64_t end_v = out_nodes[b].const_value; + constexpr int64_t kMaxOverRangeIterations = int64_t{1} << 24; + QD_ERROR_IF(end_v > begin_v && end_v - begin_v > kMaxOverRangeIterations, + "SerializedSizeExpr MaxOverRange iteration count {} exceeds the {} guard; refusing to " + "enumerate. Shrink the enclosing reverse-mode loop or restructure the `SizeExpr` source " + "kernel.", + end_v - begin_v, kMaxOverRangeIterations); + } + int32_t body = encode_subtree(src, node.body_node_idx, contains_device_leaf, free_vars, var_id_remap, prog, ctx, + fl_emitter, out_nodes, out_indices, reads); + AdStackSizeExprDeviceNode dn = + make_empty_device_node(static_cast(AdStackSizeExprDeviceKind::kMaxOverRange)); + dn.operand_a = a; + dn.operand_b = b; + dn.body_node_idx = body; + dn.var_id = remap_var_id(var_id_remap, node.var_id); + out_nodes.push_back(dn); + return static_cast(out_nodes.size() - 1); + } + case SizeExpr::Kind::ExternalTensorRead: { + QD_ASSERT_INFO(ctx != nullptr && ctx->args_type != nullptr, + "encode_subtree: ExternalTensorRead at node {} requires a LaunchContextBuilder with a valid " + "args_type to precompute the data_ptr offset", + src_idx); + QD_ASSERT_INFO(!node.arg_id_path.empty(), "ExternalTensorRead at node {} has empty arg_id_path", src_idx); + std::vector arg_indices(node.arg_id_path.begin(), node.arg_id_path.end()); + arg_indices.push_back(TypeFactory::DATA_PTR_POS_IN_NDARRAY); + const size_t data_ptr_offset = ctx->args_type->get_element_offset(arg_indices); + AdStackSizeExprDeviceNode dn = + make_empty_device_node(static_cast(AdStackSizeExprDeviceKind::kExternalTensorRead)); + // Cast to i32 is safe: `arg_buffer` sizes in practice are kilobytes, well under INT32_MAX. + dn.arg_buffer_offset = static_cast(data_ptr_offset); + dn.prim_dt = static_cast(node.const_value); // the pre-pass stashes `PrimitiveTypeID` in const_value + dn.indices_offset = static_cast(out_indices.size()); + dn.indices_count = static_cast(node.indices.size()); + // Pre-compute per-axis element strides in C order (`stride[k] = prod_{m > k} shape[m]`). 
Shapes live in + // the kernel args struct as `int32` slots at the `SHAPE_POS_IN_NDARRAY` path, same source the host + // evaluator reads; using the live launch context keeps strides consistent with whichever ndarray the + // user handed to the kernel on this launch. Emit as `[idx_a_raw, elem_stride_a]` pairs per axis, + // matching the `kFieldLoad` layout so the device interpreter and SPIR-V sizer shader can share one + // pair-walking offset-computation loop instead of carrying a separate stride-1 path. + std::vector elem_strides(node.indices.size(), 1); + if (node.indices.size() > 1) { + for (std::size_t k = node.indices.size(); k-- > 0;) { + if (k + 1 < node.indices.size()) { + std::vector sh_idx(node.arg_id_path.begin(), node.arg_id_path.end()); + sh_idx.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY); + sh_idx.push_back(static_cast(k + 1)); + int32_t sh = ctx->get_struct_arg_host(sh_idx); + elem_strides[k] = elem_strides[k + 1] * sh; + } + } + } + for (std::size_t k = 0; k < node.indices.size(); ++k) { + int32_t raw = node.indices[k]; + if (raw < 0) { + // Remap bound-variable refs so the device interpreter's `scope->values[var]` read lands in the + // `[0, kAdStackSizeExprDeviceMaxBoundVars)` range regardless of how large the source tree's + // `var_id_counter` grew across its push-site walks. + int32_t dense = remap_var_id(var_id_remap, -(raw + 1)); + raw = -(dense + 1); + } + out_indices.push_back(raw); + out_indices.push_back(elem_strides[k]); + } + out_nodes.push_back(dn); + return static_cast(out_nodes.size() - 1); + } + case SizeExpr::Kind::FieldLoad: { + // If we reach here the subtree is not host-substitutable (has free bound vars or sits alongside an + // `ExternalTensorRead` in the same closed context, or - on the SPIR-V path - `FieldLoad` is deliberately + // kept on the device via `fl_emitter`). Without an emitter, the LLVM path would have folded it earlier; + // reaching here without one means the shape is outside what the grammar supports, which is a user-facing + // bug, not a runtime fallback. + QD_ASSERT_INFO( + !fl_emitter.empty(), + "Adstack SizeExpr FieldLoad at node {} survived the host fold without a FieldLoadDeviceEmitter. The " + "LLVM encoder should route closed FieldLoads through `evaluate_node` and reject non-closed ones before " + "the structural pre-pass emits them; if this fires, a SerializedSizeExpr with a bound-var-indexed " + "FieldLoad leaf reached an LLVM-targeted encoder (which cannot resolve it on-device).", + src_idx); + QD_ASSERT_INFO(node.snode_id >= 0, "FieldLoad at node {} has no snode_id", src_idx); + QD_ASSERT_INFO(prog != nullptr, "encode_subtree: FieldLoad needs a live Program to resolve snode {}", + node.snode_id); + SNode *snode = prog->get_snode_by_id(node.snode_id); + QD_ASSERT_INFO(snode != nullptr, + "FieldLoad at node {} references snode_id={} which is not in the program's snode tree", src_idx, + node.snode_id); + uint64_t base_psb = 0; + std::vector elem_strides; + bool fetched = fl_emitter.fetch(snode, &base_psb, &elem_strides); + QD_ERROR_IF(!fetched, + "Adstack SizeExpr FieldLoad at node {} on snode_id={} could not be resolved for device-side " + "evaluation: the snode layout is not a pure-dense chain ending in a plain place leaf (bitmasked " + "/ pointer / bit-level snodes are not supported by the SPIR-V sizer shader). 
Rewrite the trip " + "count to use a dense field, or extend the shader to walk the non-dense hierarchy.", + src_idx, node.snode_id); + QD_ASSERT_INFO(elem_strides.size() == node.indices.size(), + "FieldLoad at node {}: elem_strides.size()={} must match node.indices.size()={} (one stride " + "per active axis)", + src_idx, elem_strides.size(), node.indices.size()); + AdStackSizeExprDeviceNode dn = + make_empty_device_node(static_cast(AdStackSizeExprDeviceKind::kFieldLoad)); + dn.const_value = static_cast(base_psb); + // `PrimitiveTypeID` for the leaf: mirrors ExternalTensorRead's field. The pre-pass emits a `FieldLoad` + // `SerializedSizeExprNode` with `snode_id` set and the element type implicit in the snode; we look it up + // here so the shader's existing `emit_psb_load_i64` switch (shared with ETR) can dispatch on it. + dn.prim_dt = static_cast(snode->dt->cast()->type); + dn.indices_offset = static_cast(out_indices.size()); + dn.indices_count = static_cast(node.indices.size()); + // Interleaved `[idx_a_raw, elem_stride_a]` pairs per axis. The shader reads 2 i32s per axis and + // accumulates `idx_a * stride_a` into the element index, then `psb_load_scalar` multiplies by the + // element size to get the final byte offset. Bound-variable refs (negative entries) are dense-remapped + // so the device interpreter's fixed-size `scope->values[]` stays in bounds. + for (std::size_t a = 0; a < node.indices.size(); ++a) { + int32_t raw = static_cast(node.indices[a]); + if (raw < 0) { + int32_t dense = remap_var_id(var_id_remap, -(raw + 1)); + raw = -(dense + 1); + } + out_indices.push_back(raw); + out_indices.push_back(elem_strides[a]); + } + out_nodes.push_back(dn); + return static_cast(out_nodes.size() - 1); + } + case SizeExpr::Kind::ExternalTensorShape: { + // Should have been folded to `Const` by the `subtree_needs_device == false && subtree_closed == true` + // branch above, since `ExternalTensorShape` has no free vars and cannot be an `ExternalTensorRead`. Hitting + // this branch is a bug in the encoder, not in the kernel. + QD_ERROR( + "Adstack SizeExpr ExternalTensorShape at node {} escaped host folding: this is an encoder invariant" + " violation - shape nodes are always closed and should have been emitted as Const.", + src_idx); + return -1; + } + } + QD_ERROR("encode_subtree: unreachable kind {} at node {}", node.kind, src_idx); + return -1; +} + +// Shared back-end for both encoder variants. Takes already-populated stack headers (with +// `entry_size_bytes` / `max_size_compile_time` / `heap_kind` set per stack, `root_node_idx` defaulted to +// `-1`) plus the per-stack source trees, runs the tree-to-bytecode substitution-aware flattening, and +// returns the packed byte buffer ready to upload to a device scratch buffer. +std::vector encode_bytecode_common(std::vector stack_headers, + const std::vector &exprs, + Program *prog, + LaunchContextBuilder *ctx, + const FieldLoadDeviceEmitter &fl_emitter, + int max_nodes_per_stack = 0, + ReadSink *reads = nullptr) { + const std::size_t n_stacks = stack_headers.size(); + QD_ASSERT(exprs.size() == n_stacks); + + std::vector nodes; + std::vector indices; + nodes.reserve(n_stacks); + indices.reserve(n_stacks); + + const bool fieldload_stays_on_device = !fl_emitter.empty(); + for (std::size_t i = 0; i < n_stacks; ++i) { + auto &sh = stack_headers[i]; + const SerializedSizeExpr *expr = exprs[i]; + if (std::getenv("QD_DEBUG_ADSTACK")) { + fprintf(stderr, "[encode] stack[%zu]: expr=%p nodes=%zu max_size_ct=%u\n", i, (void *)expr, + expr ? 
expr->nodes.size() : 0, sh.max_size_compile_time); + if (expr) { + for (size_t n = 0; n < expr->nodes.size(); ++n) { + const auto &node = expr->nodes[n]; + fprintf(stderr, + "[encode] node[%zu]: kind=%d const=%lld snode_id=%d var_id=%d op_a=%d op_b=%d body=%d " + "axis=%d", + n, node.kind, (long long)node.const_value, node.snode_id, node.var_id, node.operand_a, node.operand_b, + node.body_node_idx, node.arg_shape_axis); + if (!node.arg_id_path.empty()) { + fprintf(stderr, " arg_id=["); + for (int32_t v : node.arg_id_path) + fprintf(stderr, "%d,", v); + fprintf(stderr, "]"); + } + if (!node.indices.empty()) { + fprintf(stderr, " idx=["); + for (int32_t v : node.indices) + fprintf(stderr, "%d,", v); + fprintf(stderr, "]"); + } + fprintf(stderr, "\n"); + } + // Host-side ground-truth evaluation: if the shader later writes a different `max_size`, the delta + // pinpoints a shader-side bug rather than a pre-pass / SerializedSizeExpr bug. Skip when the caller + // passes `ctx == nullptr` (C++-only test harnesses) and when an `ExternalTensorRead` leaf exists but + // `ctx->array_ptrs` has not been populated (the CPU launcher populates it via + // `set_host_accessible_ndarray_ptrs`; SPIR-V launchers use the device-side PSB path instead, so + // `array_ptrs` is empty and the host-eval would crash on the missing key). + if (!expr->nodes.empty() && prog != nullptr && ctx != nullptr) { + bool has_etr = false; + for (const auto &node : expr->nodes) { + if (static_cast(node.kind) == SizeExpr::Kind::ExternalTensorRead) { + has_etr = true; + break; + } + } + if (has_etr && ctx->array_ptrs.empty()) { + fprintf(stderr, "[encode] stack[%zu]: host_eval=skipped (ctx->array_ptrs empty)\n", i); + } else { + int64_t host_val = evaluate_adstack_size_expr(*expr, prog, ctx); + fprintf(stderr, "[encode] stack[%zu]: host_eval=%lld\n", i, (long long)host_val); + } + } + } + } + if (expr == nullptr || expr->nodes.empty()) { + // No symbolic bound captured - the device interpreter will route this slot to `max_size_compile_time`. + sh.root_node_idx = -1; + continue; + } + auto contains_device_leaf = compute_contains_device_leaf(*expr, fieldload_stays_on_device); + auto free_vars = compute_free_vars(*expr); + const std::size_t root_src_idx = expr->nodes.size() - 1; + QD_ASSERT_INFO(free_vars[root_src_idx].empty(), + "Adstack SizeExpr tree root for stack {} has {} free bound variable(s); a well-formed tree" + " must be closed at the root because no outer MaxOverRange scope exists at publish time.", + i, free_vars[root_src_idx].size()); + // Dense-remap the tree's `var_id`s before emitting device nodes: `var_id_counter` on the host is a monotonic + // per-alloca counter bumped at every chased non-const index / stash, so a complex reverse-mode kernel can + // exceed the device interpreter's fixed-size scope capacity even with modest nesting. The encoder hard-errors + // here rather than letting the interpreter silently drop binds and return wrong `max_size` values. + auto var_id_remap = build_dense_var_id_remap(*expr); + const int32_t mor_depth = compute_max_mor_nesting(*expr); + QD_ERROR_IF(mor_depth > spirv::kAdStackSizerMaxPendingFrames, + "Adstack SizeExpr for stack {} has MaxOverRange nesting depth {}, which exceeds the sizer shader's " + "`kAdStackSizerMaxPendingFrames` ({}) pending-frame capacity. Past this cap the shader's fixed-size " + "pending-frame stack would index out of bounds - SPIR-V private-storage OOB is UB. 
Shrink the " + "enclosing reverse-mode loop nesting or file a bug so the cap can be raised.", + i, mor_depth, spirv::kAdStackSizerMaxPendingFrames); + const std::size_t nodes_before = nodes.size(); + sh.root_node_idx = encode_subtree(*expr, static_cast(root_src_idx), contains_device_leaf, free_vars, + var_id_remap, prog, ctx, fl_emitter, nodes, indices, reads); + if (max_nodes_per_stack > 0) { + const std::size_t per_stack = nodes.size() - nodes_before; + QD_ERROR_IF(per_stack > static_cast(max_nodes_per_stack), + "Adstack SizeExpr for stack {} encodes {} device nodes, which exceeds the sizer shader's per-stack " + "`kAdStackSizerMaxNodesPerStack` ({}) scratch capacity. Shrink the reverse-mode loop shape or file a " + "bug - past this cap the on-device interpreter would silently truncate its private `values_arr` and " + "surface later as a mysterious adstack overflow.", + i, per_stack, max_nodes_per_stack); + } + } + + // Pack everything into a flat byte buffer: header | stack_headers | nodes | indices. + AdStackSizeExprDeviceHeader header{}; + header.n_stacks = static_cast(n_stacks); + header.total_nodes = static_cast(nodes.size()); + header.total_indices = static_cast(indices.size()); + header._pad = 0; + + const std::size_t bytes_header = sizeof(AdStackSizeExprDeviceHeader); + const std::size_t bytes_stack_headers = sizeof(AdStackSizeExprDeviceStackHeader) * n_stacks; + const std::size_t bytes_nodes = sizeof(AdStackSizeExprDeviceNode) * nodes.size(); + const std::size_t bytes_indices = sizeof(int32_t) * indices.size(); + const std::size_t total_bytes = bytes_header + bytes_stack_headers + bytes_nodes + bytes_indices; + + std::vector buffer(total_bytes); + std::size_t cursor = 0; + std::memcpy(buffer.data() + cursor, &header, bytes_header); + cursor += bytes_header; + if (bytes_stack_headers > 0) { + std::memcpy(buffer.data() + cursor, stack_headers.data(), bytes_stack_headers); + cursor += bytes_stack_headers; + } + if (bytes_nodes > 0) { + std::memcpy(buffer.data() + cursor, nodes.data(), bytes_nodes); + cursor += bytes_nodes; + } + if (bytes_indices > 0) { + std::memcpy(buffer.data() + cursor, indices.data(), bytes_indices); + cursor += bytes_indices; + } + QD_ASSERT(cursor == total_bytes); + return buffer; +} + +} // namespace + +std::vector encode_adstack_size_expr_device_bytecode(const AdStackSizingInfo &ad_stack, + Program *prog, + LaunchContextBuilder *ctx, + const MaxReducerResultMap &max_reducer_results) { + const std::size_t n_stacks = ad_stack.allocas.size(); + std::vector stack_headers(n_stacks); + std::vector exprs(n_stacks, nullptr); + // Per-stack substituted trees: if the max-reducer dispatched a value for any + // captured `MaxOverRange`, swap it in as a `Const` BEFORE the device interpreter walks the tree. Storage owns + // the substituted copies so `exprs[i]` (a pointer) remains valid through `encode_bytecode_common`. + std::vector substituted_storage(n_stacks); + for (std::size_t i = 0; i < n_stacks; ++i) { + stack_headers[i].entry_size_bytes = static_cast(ad_stack.allocas[i].entry_size_bytes); + stack_headers[i].max_size_compile_time = static_cast(ad_stack.allocas[i].max_size_compile_time); + // Float allocas land on the lazy float heap, int allocas on the eager int heap. The encoding (`0` = float, `1` = + // int) matches the SPIR-V `AdStackHeapKind` so the offline-cache bytecode survives a backend swap. + stack_headers[i].heap_kind = (ad_stack.allocas[i].heap_kind == AdStackAllocaInfo::HeapKind::Float) ? 
0u : 1u; + if (i < ad_stack.size_exprs.size()) { + if (!max_reducer_results.empty()) { + substituted_storage[i] = substitute_precomputed_max_over_range(ad_stack.size_exprs[i], ad_stack.registry_id, + static_cast(i), max_reducer_results); + exprs[i] = &substituted_storage[i]; + } else { + exprs[i] = &ad_stack.size_exprs[i]; + } + } + } + // LLVM path: default-constructed emitter routes every FieldLoad through the host-fold (via `read_int`). That + // is safe on CPU / CUDA / AMDGPU where a nested accessor kernel launch does not conflict with the enclosing + // kernel prep. + FieldLoadDeviceEmitter fl_emitter{}; + return encode_bytecode_common(std::move(stack_headers), exprs, prog, ctx, fl_emitter); +} + +bool compute_dense_snode_strides(SNode *leaf, std::vector *out_elem_strides) { + if (leaf == nullptr) { + return false; + } + if (leaf->type != SNodeType::place) { + return false; + } + if (!leaf->is_path_all_dense) { + // A pointer / bitmasked / hash ancestor requires an on-device activation lookup the sizer shader does not + // implement. Pushing that into the shader would mean pulling the full SNode codegen subsystem in; refuse here + // and let the caller raise a user-visible "dense only" error. + return false; + } + if (leaf->is_bit_level) { + return false; // quant array / bit-struct leaves need bit-packing logic we do not emit here + } + // Refuse multi-child dense parents. The stride computation below assumes the place leaf is the sole occupant of its + // parent dense cell: `prod(shape[k+1..])` is a valid element-unit stride only when the physical cell size equals + // `sizeof(leaf_dtype)`. With multiple `.place(...)` siblings under the same dense ancestor (AoS layout), the real + // per-axis element stride is `cell_size / sizeof(leaf_dtype)`, so this function's output would land on a sibling + // field at `i >= 1`. Extending to cell-size-aware strides would require walking `SNodeDescriptor` memory-offset + // metadata the sizer shader does not consume today; refuse and surface a clear "dense-only, single-place parent" + // error instead. + for (const SNode *anc = leaf; anc != nullptr && anc->parent != nullptr; anc = anc->parent) { + const SNode *p = anc->parent; + if (p->type == SNodeType::dense && p->ch.size() > 1) { + return false; + } + } + const int n = leaf->num_active_indices; + if (n < 0) { + return false; + } + // Scalar fields (`qd.field(dt, shape=())`) have `num_active_indices == 0`; the pre-pass emits a `FieldLoad` with an + // empty `indices` vector and the shader should just load `*base_psb` without any index computation. Return an empty + // strides vector - `compute_field_load_elem_index`'s loop iterates zero times and produces `elem_idx = 0`, which + // `psb_load_scalar` resolves to the exact place address. 
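+  // Worked example (illustrative): a dense `qd.field(qd.i32, shape=(4, 3))` leaf yields shape = {4, 3} and
+  // elem_strides = {3, 1}, so `f[i, j]` resolves to elem_idx = 3 * i + 1 * j before `psb_load_scalar` scales by
+  // the element size; a `shape=()` scalar field takes the empty-strides path described above.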
+  std::vector<int> shape(n, 0);
+  for (int a = 0; a < n; ++a) {
+    int s = leaf->shape_along_axis(a);
+    if (s <= 0) {
+      return false;
+    }
+    shape[a] = s;
+  }
+  out_elem_strides->resize(n);
+  for (int a = 0; a < n; ++a) {
+    int64_t stride = 1;
+    for (int b = a + 1; b < n; ++b) {
+      stride *= shape[b];
+      if (stride > std::numeric_limits<int32_t>::max()) {
+        return false;  // would overflow the i32 slot; refuse rather than encode a truncated stride
+      }
+    }
+    (*out_elem_strides)[a] = static_cast<int32_t>(stride);
+  }
+  return true;
+}
+
+std::vector<char> encode_adstack_size_expr_device_bytecode_for_spirv(
+    const spirv::TaskAttributes::AdStackSizingAttribs &ad_stack,
+    Program *prog,
+    LaunchContextBuilder *ctx,
+    const MaxReducerResultMap &max_reducer_results) {
+  const std::size_t n_stacks = ad_stack.allocas.size();
+  std::vector<AdStackSizeExprDeviceStackHeader> stack_headers(n_stacks);
+  std::vector<const SerializedSizeExpr *> exprs(n_stacks, nullptr);
+  // Per-stack substituted trees: when the max-reducer dispatched a value for
+  // a captured `MaxOverRange` node, substitute it as a `Const` BEFORE the device sizer encoder walks the tree.
+  // Storage owns the substituted copies so `exprs[i]` (a pointer) stays valid through `encode_bytecode_common`.
+  std::vector<SerializedSizeExpr> substituted_storage(n_stacks);
+  for (std::size_t i = 0; i < n_stacks; ++i) {
+    const auto &a = ad_stack.allocas[i];
+    // The SPIR-V heaps are element-indexed (f32 / i32), so `entry_size_bytes` in the device header would be
+    // misnamed if we set it to the byte count; the SPIR-V sizer shader interprets this field as element count
+    // and scales by `2` only for the `Float` heap (to cover primal + adjoint interleaved), matching the
+    // `running_offset_float += 2u * max_size` / `running_offset_int += max_size` convention the host path used
+    // to perform and the main-kernel code already bakes into its offset arithmetic. Stamp `1` here so the
+    // sizer's multiplication by `2` for the float heap lands exactly on `2 * max_size` and the int heap on
+    // `1 * max_size`.
+    stack_headers[i].entry_size_bytes = 1;
+    stack_headers[i].max_size_compile_time = a.max_size_compile_time;
+    stack_headers[i].heap_kind = static_cast<uint32_t>(a.heap_kind);  // Float = 0, Int = 1
+    if (!max_reducer_results.empty()) {
+      substituted_storage[i] = substitute_precomputed_max_over_range(a.size_expr, ad_stack.registry_id,
+                                                                     static_cast<int32_t>(i), max_reducer_results);
+      exprs[i] = &substituted_storage[i];
+    } else {
+      exprs[i] = &a.size_expr;
+    }
+  }
+  // SPIR-V path: emit `FieldLoad` as `kFieldLoad` device nodes so the sizer shader can PSB-load the field value
+  // in place. This avoids `SNodeRwAccessorsBank::Accessors::read_int`, whose nested accessor-kernel launch
+  // deadlocks inside MoltenVK's descriptor-set bind path when the outer launch has already opened its command
+  // buffer. The emitter resolves each snode's tree-root PSB via the program's compute device and pre-computes
+  // the per-axis element strides from the dense snode shape.
+ FieldLoadDeviceEmitter fl_emitter{}; + fl_emitter.fetch = [prog](SNode *snode, uint64_t *out_base_psb, std::vector *out_elem_strides) -> bool { + if (snode == nullptr || prog == nullptr) { + return false; + } + if (!compute_dense_snode_strides(snode, out_elem_strides)) { + return false; + } + const int tree_id = snode->get_snode_tree_id(); + DevicePtr tree_root_devptr = prog->get_snode_tree_device_ptr(tree_id); + Device *dev = prog->get_compute_device(); + if (dev == nullptr) { + return false; + } + // `get_memory_physical_pointer` returns the Vulkan `bufferDeviceAddress` / Metal equivalent for the buffer that + // backs the snode tree's root. The place's byte offset within the tree comes from the compiled snode descriptor + // table (`snode_descriptors[id].mem_offset_in_parent_cell` walked up to root), NOT from + // `SNode::offset_bytes_in_parent_cell` which is a frontend-only field that stays zero on the SPIR-V path. Using the + // wrong offset silently reads a sibling field (typically the first `qd.field` declared in the program), which looks + // like a returning-zero bug at runtime. + uint64_t root_psb = dev->get_memory_physical_pointer(tree_root_devptr); + if (root_psb == 0) { + return false; + } + size_t place_byte_offset = prog->get_field_in_tree_offset(tree_id, snode); + if (std::getenv("QD_DEBUG_ADSTACK")) { + // Pull the live value via the RwAccessors as a ground-truth check on the `PSB + place_off` pair the + // sizer shader will consume: if the shader later reads a different i32 than `live_val`, the byte offset + // we encoded is wrong even though `shape_along_axis` lined up and the tree dispatcher emitted a sensible + // PSB base. `read_int` launches its own accessor kernel plus a `synchronize()`, which is safe here + // because the encoder runs outside any in-flight main-kernel launch. + std::vector idx_zero(snode->num_active_indices, 0); + int64_t live_val = prog->get_snode_rw_accessors_bank().get(snode).read_int(idx_zero); + fprintf(stderr, + "[fl.fetch] snode_id=%d type=%d dense=%d n_axes=%d tree_id=%d root_psb=0x%llx place_off=%zu live=%lld\n", + snode->id, (int)snode->type, (int)snode->is_path_all_dense, snode->num_active_indices, tree_id, + (unsigned long long)root_psb, place_byte_offset, (long long)live_val); + } + *out_base_psb = root_psb + static_cast(place_byte_offset); + return true; + }; + // Bytecode fast path: replay the recorded host-fold reads against the live state and reuse the cached + // bytecode if every input still matches. The full encode runs only on cache miss. + if (prog != nullptr) { + std::vector cached; + if (prog->adstack_cache().try_spirv_bytecode_cache_hit(prog, static_cast(&ad_stack), ctx, cached)) { + return cached; + } + } + std::vector reads; + std::vector bytecode = encode_bytecode_common(std::move(stack_headers), exprs, prog, ctx, fl_emitter, + spirv::kAdStackSizerMaxNodesPerStack, &reads); + // Thread the max-reducer body's read observations into the bytecode cache entry so a mutation to the gating + // ndarray invalidates the cached bytecode (the encoder walked the post-substitution tree where each captured + // `MaxOverRange` has collapsed to a `Const`, so the body's `ExternalTensorRead` leaves are not in `reads`). + // The observations were populated by the dispatch site via `populate_max_reducer_body_observations` and + // recorded into the `max_reducer_cache_` alongside the dispatched value. 
On a subsequent launch the bytecode + // cache replays them; gen-mismatch paths hit the dereference branch in `replay_one_observation` which returns + // a value other than the recorded `INT64_MIN` sentinel and forces invalidation. + if (prog != nullptr) { + for (const auto &spec : ad_stack.max_reducer_specs) { + const auto *spec_reads = + prog->adstack_cache().lookup_max_reducer_reads(ad_stack.registry_id, spec.stack_id, spec.mor_node_idx); + if (spec_reads != nullptr) { + reads.insert(reads.end(), spec_reads->begin(), spec_reads->end()); + } + } + prog->adstack_cache().record_spirv_bytecode_eval(static_cast(&ad_stack), bytecode, std::move(reads)); + } + return bytecode; +} + +} // namespace quadrants::lang diff --git a/quadrants/program/adstack/device_bytecode.h b/quadrants/program/adstack/device_bytecode.h new file mode 100644 index 0000000000..ab26c02f28 --- /dev/null +++ b/quadrants/program/adstack/device_bytecode.h @@ -0,0 +1,85 @@ +#pragma once + +#include +#include +#include + +#include "quadrants/codegen/llvm/llvm_compiled_data.h" +#include "quadrants/codegen/spirv/kernel_utils.h" +#include "quadrants/ir/adstack_size_expr.h" +#include "quadrants/program/adstack/max_reducer.h" + +namespace quadrants::lang { + +class LaunchContextBuilder; +class Program; +class SNode; + +// Data needed to encode a `FieldLoad` as a `kFieldLoad` device node on the SPIR-V backend. Populated by the SPIR-V +// dispatch site (per-task sizer or max-reducer) via `GfxRuntime` / `Device` queries; the LLVM encoder paths pass a +// default-constructed (`empty()`) emitter and resolve `(snode_root_id, place_byte_offset)` directly via `prog`. The +// fetch closure returns `out_base_psb = root_buffer_psb + place_byte_offset_in_root` and per-active-axis element +// strides (in units of the leaf primitive type, not bytes - the shader multiplies by `sizeof(prim_dt)` separately via +// `psb_load_scalar`). Returns false when the snode is not amenable to direct PSB indexing (bitmasked / pointer / hash +// chain, bit-level place, not-all-dense path); the encoder treats that as a hard error on the per-task sizer path or +// drops the spec on the max-reducer path. +struct FieldLoadDeviceEmitter { + std::function *out_elem_strides)> fetch; + + bool empty() const { + return fetch == nullptr; + } +}; + +// Compute per-active-axis element strides for a dense `place`-leaf SNode (units = leaf primitive type, not bytes). +// Matches the SPIR-V FieldLoad emitter's stride convention; the max-reducer encoder reuses this to lay out the +// `[idx_a_raw, elem_stride_a]` indices-table pairs that the body interpreter walks. Returns false on non-dense / +// bit-level / multi-child-dense layouts (same restriction as the per-task sizer's `FieldLoadDeviceEmitter::fetch`). +bool compute_dense_snode_strides(SNode *leaf, std::vector *out_elem_strides); + +// Flattens every alloca's `SerializedSizeExpr` tree into the device-readable bytecode defined in +// `quadrants/ir/adstack_size_expr_device.h` and returns the raw bytes ready to upload to a device scratch buffer. +// Two transforms happen at encoding time: +// +// 1. Pre-substitution of host-resolvable subtrees. Any subtree whose leaves consist only of `Const`, +// `BoundVariable`, `FieldLoad`, and `ExternalTensorShape` nodes - i.e. nothing that requires an +// on-device pointer dereference - is collapsed to a single `Const` node by running the existing host +// evaluator over it. 
This routes `FieldLoad` through `SNodeRwAccessorsBank::read_int` (which itself +// handles device-to-host via a tiny reader kernel on GPU) and `ExternalTensorShape` through the kernel +// arg buffer that the host just wrote, so the device interpreter in `runtime.cpp` never has to walk +// an SNode tree or index into `args_type` - it only has to handle arithmetic plus +// `ExternalTensorRead`, which is the one leaf kind that actually needs device-resident memory. +// 2. `arg_buffer_offset` precomputation. Every surviving `ExternalTensorRead` carries the byte offset into +// `RuntimeContext::arg_buffer` where the referenced ndarray's data pointer lives, resolved here against +// `ctx->args_type->get_element_offset({arg_id, DATA_PTR_POS_IN_NDARRAY})`. The device interpreter does +// a direct `*(void **)(arg_buffer + offset)` to fetch the ndarray pointer at launch time - no map +// lookup, no `LaunchContextBuilder` touches from device code. + +// Mixed subtrees that contain both an `ExternalTensorRead` and a `FieldLoad` are rejected with a hard error: +// the device interpreter does not support on-device SNode access, so a `FieldLoad` that cannot be lifted out +// to a host-resolvable `Const` has nowhere to run. The grammar today does not emit this combination and no +// user kernel has been observed to do so; the hard error pins the assumption so a future regression cannot +// slip past. +std::vector encode_adstack_size_expr_device_bytecode( + const AdStackSizingInfo &ad_stack, + Program *prog, + LaunchContextBuilder *ctx, + const MaxReducerResultMap &max_reducer_results = MaxReducerResultMap{}); + +// SPIR-V-flavour encoder. Same transforms as the LLVM variant, but sources per-stack metadata from +// `TaskAttributes::AdStackSizingAttribs::allocas` (each entry has a `HeapKind` - `Float = 0`, `Int = 1` - +// that routes the stack onto the `AdStackHeapFloat` or `AdStackHeapInt` backing buffer on the host). The +// `heap_kind` field of each `AdStackSizeExprDeviceStackHeader` carries that selector into the shader; the +// shader splits the running-offset / stride computation into a float accumulator and an int accumulator so +// the output metadata buffer matches the layout the main kernel already reads today: +// `[stride_float, stride_int, (offset_i, max_size_i)*]`. The `entry_size_bytes` field is set to 1 on the +// SPIR-V path because the backing buffers are element-indexed (f32 / i32) rather than byte-indexed and the +// shader multiplies by `2` only for the `Float` heap (primal + adjoint interleaved) - see the running-offset +// arithmetic in `GfxRuntime::launch_kernel` for the convention this matches. 
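+//
+// Usage sketch (illustrative only - `attribs`, `prog`, `ctx` are stand-ins for the dispatch site's real values and
+// `upload_to_sizer_scratch` is a hypothetical helper, not an API in this patch):
+//
+//   std::vector<char> bytecode =
+//       encode_adstack_size_expr_device_bytecode_for_spirv(attribs, prog, &ctx);  // default: no max-reducer results
+//   upload_to_sizer_scratch(bytecode.data(), bytecode.size());
+//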
+std::vector encode_adstack_size_expr_device_bytecode_for_spirv( + const spirv::TaskAttributes::AdStackSizingAttribs &ad_stack, + Program *prog, + LaunchContextBuilder *ctx, + const MaxReducerResultMap &max_reducer_results = MaxReducerResultMap{}); + +} // namespace quadrants::lang diff --git a/quadrants/program/adstack/diagnose.cpp b/quadrants/program/adstack/diagnose.cpp new file mode 100644 index 0000000000..e05def506a --- /dev/null +++ b/quadrants/program/adstack/diagnose.cpp @@ -0,0 +1,297 @@ +#include "quadrants/program/adstack/diagnose.h" + +#include +#include +#include +#include +#include +#include + +#include "quadrants/common/logging.h" +#include "quadrants/ir/type.h" +#include "quadrants/ir/type_factory.h" +#include "quadrants/program/adstack/cache.h" +#include "quadrants/program/adstack/eval.h" +#include "quadrants/program/launch_context_builder.h" +#include "quadrants/program/program.h" +#include "quadrants/rhi/device.h" + +namespace quadrants::lang { + +namespace { + +// Diagnose-time leaf reader: resolves an `ExternalTensorRead` against the captured +// `AdStackCache::DiagnoseLaunchSnapshot` and the program's `Device::map` interface. Returns -1 on any failure +// (missing arg in snapshot, unrecognised primitive type, mapping failure) so the caller can substitute the +// `?` placeholder for that stack while keeping the rest of the message intact. +// +// Single-scalar staging-buffer pattern (mirrors `Ndarray::read` in `program/ndarray.cpp`): allocate a tiny +// `host_read=true` staging buffer, `memcpy_internal` the one element from the ndarray's device buffer into +// staging, then map staging to read the value host-side. This works on every backend because every +// `Device` implementation supports `host_read=true` allocations + `map` + `memcpy_internal`. For `kNone` +// numpy passthrough the captured pointer is already host-readable; we read it directly. +int64_t read_diagnose_external_tensor(const SerializedSizeExprNode &node, + const std::vector &resolved_indices, + Program *prog, + const AdStackCache::DiagnoseLaunchSnapshot &snapshot) { + if (node.arg_id_path.empty()) { + return -1; + } + int arg_id = node.arg_id_path[0]; + auto ptr_it = snapshot.data_ptrs.find(arg_id); + if (ptr_it == snapshot.data_ptrs.end() || ptr_it->second == nullptr) { + return -1; + } + auto type_it = snapshot.dev_alloc_types.find(arg_id); + if (type_it == snapshot.dev_alloc_types.end()) { + return -1; + } + auto shape_it = snapshot.shapes.find(arg_id); + if (shape_it == snapshot.shapes.end()) { + return -1; + } + const std::vector &shape = shape_it->second; + // Compose C-order linear offset across resolved indices (mirrors `evaluate_external_tensor_read`'s stride + // math; we cannot share the helper because that one routes through `LaunchContextBuilder::get_struct_arg_host` + // which is not available here). + if (resolved_indices.size() > shape.size() && !shape.empty()) { + // More indices than rank - the size_expr was lowered against a different shape; skip. 
+    return -1;
+  }
+  int64_t linear = 0;
+  int64_t stride = 1;
+  for (std::size_t i = resolved_indices.size(); i > 0; --i) {
+    linear += resolved_indices[i - 1] * stride;
+    if (i - 1 > 0 && i - 1 < shape.size()) {
+      stride *= static_cast<int64_t>(shape[i - 1]);
+    }
+  }
+  auto prim_dt = static_cast<PrimitiveTypeID>(node.const_value);
+  std::size_t elem_size = 0;
+  switch (prim_dt) {
+    case PrimitiveTypeID::i8:
+    case PrimitiveTypeID::u8:
+      elem_size = 1;
+      break;
+    case PrimitiveTypeID::i16:
+    case PrimitiveTypeID::u16:
+      elem_size = 2;
+      break;
+    case PrimitiveTypeID::i32:
+    case PrimitiveTypeID::u32:
+      elem_size = 4;
+      break;
+    case PrimitiveTypeID::i64:
+    case PrimitiveTypeID::u64:
+      elem_size = 8;
+      break;
+    default:
+      return -1;
+  }
+  std::size_t byte_offset = static_cast<std::size_t>(linear) * elem_size;
+  // Decode the captured pointer to host bytes.
+  std::vector<uint8_t> staging_bytes(elem_size);
+  if (type_it->second == LaunchContextBuilder::DevAllocType::kNone) {
+    // Numpy passthrough: ptr is already a raw host pointer.
+    const uint8_t *src = static_cast<const uint8_t *>(ptr_it->second) + byte_offset;
+    std::memcpy(staging_bytes.data(), src, elem_size);
+  } else if (type_it->second == LaunchContextBuilder::DevAllocType::kNdarray) {
+    if (prog == nullptr) {
+      return -1;
+    }
+    auto *alloc = static_cast<DeviceAllocation *>(ptr_it->second);
+    if (alloc == nullptr || alloc->device == nullptr) {
+      return -1;
+    }
+    Device::AllocParams params;
+    params.host_write = false;
+    params.host_read = true;
+    params.size = elem_size;
+    params.usage = AllocUsage::Storage;
+    auto [staging, alloc_res] = alloc->device->allocate_memory_unique(params);
+    if (alloc_res != RhiResult::success || !staging) {
+      return -1;
+    }
+    alloc->device->memcpy_internal(staging->get_ptr(), alloc->get_ptr(byte_offset), elem_size);
+    void *mapped = nullptr;
+    if (alloc->device->map(*staging, &mapped) != RhiResult::success || mapped == nullptr) {
+      return -1;
+    }
+    std::memcpy(staging_bytes.data(), mapped, elem_size);
+    alloc->device->unmap(*staging);
+  } else {
+    return -1;
+  }
+  // Sign- / zero-extend to int64 according to the captured primitive type.
+  switch (prim_dt) {
+    case PrimitiveTypeID::i8:
+      return static_cast<int64_t>(*reinterpret_cast<const int8_t *>(staging_bytes.data()));
+    case PrimitiveTypeID::u8:
+      return static_cast<int64_t>(*reinterpret_cast<const uint8_t *>(staging_bytes.data()));
+    case PrimitiveTypeID::i16:
+      return static_cast<int64_t>(*reinterpret_cast<const int16_t *>(staging_bytes.data()));
+    case PrimitiveTypeID::u16:
+      return static_cast<int64_t>(*reinterpret_cast<const uint16_t *>(staging_bytes.data()));
+    case PrimitiveTypeID::i32:
+      return static_cast<int64_t>(*reinterpret_cast<const int32_t *>(staging_bytes.data()));
+    case PrimitiveTypeID::u32:
+      return static_cast<int64_t>(*reinterpret_cast<const uint32_t *>(staging_bytes.data()));
+    case PrimitiveTypeID::i64:
+      return *reinterpret_cast<const int64_t *>(staging_bytes.data());
+    case PrimitiveTypeID::u64:
+      return static_cast<int64_t>(*reinterpret_cast<const uint64_t *>(staging_bytes.data()));
+    default:
+      return -1;
+  }
+}
+
+// Mirror of `evaluate_node` for diagnose-time evaluation. Same tree-walk semantics; differs only in the leaf
+// case for `ExternalTensorRead` / `ExternalTensorShape`, which route through the snapshot + `Device::map` path
+// instead of `LaunchContextBuilder`. Returns -1 on any leaf-resolution failure to short-circuit the rest of
+// the walk and let the caller fall back to the static dual-cause body.
+int64_t evaluate_node_for_diagnose(const SerializedSizeExpr &expr, + int32_t node_idx, + std::unordered_map &bound_vars, + Program *prog, + const AdStackCache::DiagnoseLaunchSnapshot &snapshot) { + if (node_idx < 0 || static_cast(node_idx) >= expr.nodes.size()) { + return -1; + } + const auto &node = expr.nodes[node_idx]; + switch (static_cast(node.kind)) { + case SizeExpr::Kind::Const: + return node.const_value; + case SizeExpr::Kind::FieldLoad: { + // Field reads stay on the existing host path - they do not depend on `LaunchContextBuilder` and the + // SNode reader-kernel dispatch is host-driven. We pass `nullptr` ReadSink so the recorded observations + // do not leak into the cache from a diagnose-only walk. + return evaluate_field_load(node, bound_vars, prog, /*reads=*/nullptr); + } + case SizeExpr::Kind::Add: { + int64_t a = evaluate_node_for_diagnose(expr, node.operand_a, bound_vars, prog, snapshot); + int64_t b = evaluate_node_for_diagnose(expr, node.operand_b, bound_vars, prog, snapshot); + if (a < 0 || b < 0) { + return -1; + } + return a + b; + } + case SizeExpr::Kind::Sub: { + int64_t a = evaluate_node_for_diagnose(expr, node.operand_a, bound_vars, prog, snapshot); + int64_t b = evaluate_node_for_diagnose(expr, node.operand_b, bound_vars, prog, snapshot); + if (a < 0 || b < 0) { + return -1; + } + return std::max(a - b, 0); + } + case SizeExpr::Kind::Mul: { + int64_t a = evaluate_node_for_diagnose(expr, node.operand_a, bound_vars, prog, snapshot); + int64_t b = evaluate_node_for_diagnose(expr, node.operand_b, bound_vars, prog, snapshot); + if (a < 0 || b < 0) { + return -1; + } + return a * b; + } + case SizeExpr::Kind::Max: { + int64_t a = evaluate_node_for_diagnose(expr, node.operand_a, bound_vars, prog, snapshot); + int64_t b = evaluate_node_for_diagnose(expr, node.operand_b, bound_vars, prog, snapshot); + if (a < 0 || b < 0) { + return -1; + } + return std::max(a, b); + } + case SizeExpr::Kind::MaxOverRange: { + int64_t begin = evaluate_node_for_diagnose(expr, node.operand_a, bound_vars, prog, snapshot); + int64_t end = evaluate_node_for_diagnose(expr, node.operand_b, bound_vars, prog, snapshot); + if (begin < 0 || end < 0) { + return -1; + } + // Same iteration cap as the live evaluator; refusing to enumerate prevents diagnose from stalling + // the error path on a pathological trip count. + constexpr int64_t kMaxOverRangeIterations = int64_t{1} << 24; + if (end > begin && end - begin > kMaxOverRangeIterations) { + return -1; + } + int64_t result = 0; + auto prev_it = bound_vars.find(node.var_id); + bool had_prev = prev_it != bound_vars.end(); + int64_t prev_val = had_prev ? 
prev_it->second : 0; + for (int64_t i = begin; i < end; ++i) { + bound_vars[node.var_id] = i; + int64_t v = evaluate_node_for_diagnose(expr, node.body_node_idx, bound_vars, prog, snapshot); + if (v < 0) { + if (had_prev) { + bound_vars[node.var_id] = prev_val; + } else { + bound_vars.erase(node.var_id); + } + return -1; + } + if (v > result) { + result = v; + } + } + if (had_prev) { + bound_vars[node.var_id] = prev_val; + } else { + bound_vars.erase(node.var_id); + } + return result; + } + case SizeExpr::Kind::BoundVariable: { + auto it = bound_vars.find(node.var_id); + if (it == bound_vars.end()) { + return -1; + } + return it->second; + } + case SizeExpr::Kind::ExternalTensorShape: { + if (node.arg_id_path.empty()) { + return -1; + } + int arg_id = node.arg_id_path[0]; + auto shape_it = snapshot.shapes.find(arg_id); + if (shape_it == snapshot.shapes.end()) { + return -1; + } + if (node.arg_shape_axis < 0 || static_cast(node.arg_shape_axis) >= shape_it->second.size()) { + return -1; + } + return static_cast(shape_it->second[node.arg_shape_axis]); + } + case SizeExpr::Kind::ExternalTensorRead: { + // Resolve indices from bound_vars first, then dispatch to the snapshot-aware reader. + std::vector resolved(node.indices.size()); + for (std::size_t i = 0; i < node.indices.size(); ++i) { + int32_t raw = node.indices[i]; + if (raw >= 0) { + resolved[i] = raw; + } else { + int32_t var_id = -(raw + 1); + auto bv = bound_vars.find(var_id); + if (bv == bound_vars.end()) { + return -1; + } + resolved[i] = bv->second; + } + } + return read_diagnose_external_tensor(node, resolved, prog, snapshot); + } + } + return -1; +} + +} // namespace + +int64_t evaluate_adstack_size_expr_for_diagnose(const SerializedSizeExpr &expr, Program *prog) { + if (expr.nodes.empty() || prog == nullptr) { + return -1; + } + const AdStackCache::DiagnoseLaunchSnapshot *snapshot = prog->adstack_cache().get_diagnose_snapshot(); + if (snapshot == nullptr) { + return -1; + } + std::unordered_map bound_vars; + return evaluate_node_for_diagnose(expr, static_cast(expr.nodes.size() - 1), bound_vars, prog, *snapshot); +} + +} // namespace quadrants::lang diff --git a/quadrants/program/adstack/diagnose.h b/quadrants/program/adstack/diagnose.h new file mode 100644 index 0000000000..84cecc2ca2 --- /dev/null +++ b/quadrants/program/adstack/diagnose.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +#include "quadrants/ir/adstack_size_expr.h" + +namespace quadrants::lang { + +class Program; + +// Diagnose-time variant that evaluates the same `SerializedSizeExpr` against the captured +// `AdStackCache::DiagnoseLaunchSnapshot` rather than a live `LaunchContextBuilder`. Used by +// `AdStackCache::diagnose_adstack_overflow` to resolve `ExternalTensorRead` / `ExternalTensorShape` leaves at +// error time against the live (potentially mutated) ndarray contents, without needing the launch ctx that is +// gone by sync time on async backends. The cross-backend `Device::map(*allocation, &host_ptr)` path is the +// design pivot - see `AdStackCache::DiagnoseLaunchSnapshot`'s comment for the rationale (vs. re-dispatching +// the on-device sizer). Returns -1 if any leaf cannot be resolved (e.g. an arg_id missing from the snapshot, +// or an allocation whose `Device::map` fails); callers fall back to the static dual-cause body in that case. 
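+//
+// Caller-side sketch (illustrative - `stack_expr` / `allocated` and the branch bodies are stand-ins for the real
+// logic in `diagnose_adstack_overflow`):
+//
+//   int64_t required = evaluate_adstack_size_expr_for_diagnose(stack_expr, prog);
+//   if (required < 0) {
+//     // leaf unresolvable against the snapshot: fall back to the static dual-cause message body
+//   } else if (required > allocated) {
+//     // confirmed stale-cache / DLPack-bypass case: report the live bound and set `confirmed_invalid_cache`
+//   }
+//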
+int64_t evaluate_adstack_size_expr_for_diagnose(const SerializedSizeExpr &expr, Program *prog);
+
+} // namespace quadrants::lang
diff --git a/quadrants/program/adstack/eval.cpp b/quadrants/program/adstack/eval.cpp
new file mode 100644
index 0000000000..064d14a121
--- /dev/null
+++ b/quadrants/program/adstack/eval.cpp
@@ -0,0 +1,384 @@
+#include "quadrants/program/adstack/eval.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <limits>
+#include <utility>
+#include <vector>
+
+#include "quadrants/common/logging.h"
+#include "quadrants/ir/snode.h"
+#include "quadrants/ir/type.h"
+#include "quadrants/ir/type_factory.h"
+#include "quadrants/ir/type_utils.h"
+#include "quadrants/program/launch_context_builder.h"
+#include "quadrants/program/program.h"
+#include "quadrants/program/snode_rw_accessors_bank.h"
+
+namespace quadrants::lang {
+
+namespace {
+
+using ReadSink = std::vector<AdStackCache::SizeExprReadObservation>;
+
+// Per-launch cache of `FieldLoad` re-reads, keyed by `(snode_id, indices)`. Within one host-side eval root
+// call the SNode field values are pinned (no other kernel runs concurrently), so deduping repeats across
+// the size-expr trees evaluated in that window is correctness-safe.
+struct LaunchScopedReadCache {
+  struct Key {
+    int snode_id;
+    std::vector<int> indices;
+    bool operator==(const Key &o) const noexcept {
+      return snode_id == o.snode_id && indices == o.indices;
+    }
+  };
+  struct KeyHash {
+    std::size_t operator()(const Key &k) const noexcept {
+      std::size_t h = std::hash<int>{}(k.snode_id);
+      for (int v : k.indices) {
+        h ^= std::hash<int>{}(v) + 0x9e3779b97f4a7c15ull + (h << 6) + (h >> 2);
+      }
+      return h;
+    }
+  };
+  std::unordered_map<Key, int64_t, KeyHash> map;
+};
+thread_local LaunchScopedReadCache *t_launch_read_cache = nullptr;
+
+} // namespace
+
+int64_t evaluate_field_load(const SerializedSizeExprNode &node,
+                            std::unordered_map<int32_t, int64_t> &bound_vars,
+                            Program *prog,
+                            ReadSink *reads) {
+  QD_ASSERT_INFO(node.snode_id >= 0, "SerializedSizeExpr FieldLoad with no snode_id");
+  SNode *snode = prog->get_snode_by_id(node.snode_id);
+  QD_ASSERT_INFO(snode != nullptr,
+                 "SerializedSizeExpr FieldLoad snode_id={} not found in the current program's snode trees",
+                 node.snode_id);
+  std::vector<int> indices;
+  indices.reserve(node.indices.size());
+  for (int32_t raw : node.indices) {
+    if (raw >= 0) {
+      indices.push_back(raw);
+    } else {
+      int32_t var_id = -(raw + 1);
+      auto it = bound_vars.find(var_id);
+      QD_ASSERT_INFO(it != bound_vars.end(),
+                     "SerializedSizeExpr FieldLoad references unbound var_id={} (the enclosing MaxOverRange "
+                     "node must have bound it before this read)",
+                     var_id);
+      indices.push_back(static_cast<int>(it->second));
+    }
+  }
+  int64_t v = read_field_with_launch_cache(node.snode_id, indices, prog);
+  if (reads != nullptr) {
+    AdStackCache::SizeExprReadObservation obs;
+    obs.kind = AdStackCache::SizeExprReadObservation::FieldLoadObs;
+    obs.snode_id = node.snode_id;
+    obs.indices = std::move(indices);
+    obs.arg_shape_axis = 0;
+    obs.prim_dt = 0;
+    obs.observed_value = v;
+    // Snapshot the SNode's write gen so the next replay can fast-skip when no kernel has written this SNode
+    // since record time (the dominant case for a steady-state reverse-mode loop with stable bounds).
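+    // Generation protocol sketch (illustrative; the bump side lives in write_gen.cpp): every launcher whose
+    // task may write this SNode calls
+    //   prog->adstack_cache().bump_snode_write_gen(snode_id);
+    // before publishing per-task adstack metadata, so replay can treat
+    //   observed_gen == snode_write_gen(snode_id)
+    // as proof the recorded value is still fresh and skip the reader-kernel dispatch entirely.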
+    obs.observed_gen = prog->adstack_cache().snode_write_gen(node.snode_id);
+    reads->push_back(std::move(obs));
+  }
+  return v;
+}
+
+int64_t evaluate_external_tensor_read(const SerializedSizeExprNode &node,
+                                      std::unordered_map<int32_t, int64_t> &bound_vars,
+                                      Program *prog,
+                                      LaunchContextBuilder *ctx,
+                                      ReadSink *reads) {
+  QD_ASSERT_INFO(ctx != nullptr,
+                 "SerializedSizeExpr ExternalTensorRead evaluated with no LaunchContextBuilder; the launcher "
+                 "must pass the current launch's context in");
+  QD_ASSERT_INFO(!node.arg_id_path.empty(), "SerializedSizeExpr ExternalTensorRead has empty arg_id_path");
+  int arg_id = node.arg_id_path[0];
+  ArgArrayPtrKey key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
+  auto it = ctx->array_ptrs.find(key);
+  QD_ASSERT_INFO(it != ctx->array_ptrs.end(),
+                 "SerializedSizeExpr ExternalTensorRead: arg {} has no data pointer in launch context", arg_id);
+  void *data_ptr = it->second;
+  // Resolve each index (possibly via a bound variable) and compose them into the C-order linear offset
+  // `sum_i(idx_i * prod_{j>i}(shape_j))`. Multi-dim shapes are read from the launch context through the same
+  // `SHAPE_POS_IN_NDARRAY` path `ExternalTensorShape` uses, so an ndarray indexed by two or more loop variables lowers
+  // to the correct element rather than the stride-1 sum `arr_flat[i + j + ...]`. Mirrors the per-axis stride that
+  // `encode_subtree` precomputes on the SPIR-V path; on CPU the host evaluator is called directly from
+  // `publish_adstack_metadata`, so the stride math has to live here too.
+  std::vector<int64_t> resolved(node.indices.size());
+  for (std::size_t i = 0; i < node.indices.size(); ++i) {
+    int32_t raw = node.indices[i];
+    if (raw >= 0) {
+      resolved[i] = raw;
+    } else {
+      int32_t var_id = -(raw + 1);
+      auto bv = bound_vars.find(var_id);
+      QD_ASSERT_INFO(bv != bound_vars.end(), "SerializedSizeExpr ExternalTensorRead references unbound var_id={}",
+                     var_id);
+      resolved[i] = bv->second;
+    }
+  }
+  int64_t linear = 0;
+  int64_t stride = 1;
+  for (std::size_t i = node.indices.size(); i > 0; --i) {
+    linear += resolved[i - 1] * stride;
+    if (i - 1 > 0) {
+      std::vector<int> sh_idx(node.arg_id_path.begin(), node.arg_id_path.end());
+      sh_idx.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY);
+      sh_idx.push_back(static_cast<int>(i - 1));
+      // Ndarray shapes are `int32` in the args struct (same convention `evaluate_external_tensor_shape` relies on);
+      // reading as `int64` would sign-extend the adjacent slot into the shape and produce garbage strides.
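+      // Worked example of the composed offset (illustrative): for shape (4, 5, 6) and resolved
+      // indices (i, j, k), the enclosing loop accumulates
+      //   linear = k * 1 + j * 6 + i * (5 * 6)
+      // i.e. the running stride picks up shape[axis] each time the walk steps from an axis to
+      // the next-outer one, which is exactly the C-order rule stated above.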
+      stride *= static_cast<int64_t>(ctx->get_struct_arg_host<int32_t>(sh_idx));
+    }
+  }
+  auto prim_dt = static_cast<PrimitiveTypeID>(node.const_value);
+  int64_t v;
+  switch (prim_dt) {
+    case PrimitiveTypeID::i32:
+      v = static_cast<int64_t>(static_cast<const int32_t *>(data_ptr)[linear]);
+      break;
+    case PrimitiveTypeID::i64:
+      v = static_cast<const int64_t *>(data_ptr)[linear];
+      break;
+    case PrimitiveTypeID::u32:
+      v = static_cast<int64_t>(static_cast<const uint32_t *>(data_ptr)[linear]);
+      break;
+    case PrimitiveTypeID::u64:
+      v = static_cast<int64_t>(static_cast<const uint64_t *>(data_ptr)[linear]);
+      break;
+    case PrimitiveTypeID::i16:
+      v = static_cast<int64_t>(static_cast<const int16_t *>(data_ptr)[linear]);
+      break;
+    case PrimitiveTypeID::u16:
+      v = static_cast<int64_t>(static_cast<const uint16_t *>(data_ptr)[linear]);
+      break;
+    case PrimitiveTypeID::i8:
+      v = static_cast<int64_t>(static_cast<const int8_t *>(data_ptr)[linear]);
+      break;
+    case PrimitiveTypeID::u8:
+      v = static_cast<int64_t>(static_cast<const uint8_t *>(data_ptr)[linear]);
+      break;
+    default:
+      QD_ERROR("SerializedSizeExpr ExternalTensorRead: unsupported element type {}", node.const_value);
+      v = 0;
+  }
+  if (reads != nullptr) {
+    AdStackCache::SizeExprReadObservation obs;
+    obs.kind = AdStackCache::SizeExprReadObservation::ExternalReadObs;
+    obs.snode_id = 0;
+    obs.indices.reserve(resolved.size());
+    for (auto r : resolved)
+      obs.indices.push_back(static_cast<int>(r));
+    obs.arg_id_path = node.arg_id_path;
+    obs.arg_shape_axis = 0;
+    obs.prim_dt = static_cast<int>(prim_dt);
+    obs.observed_value = v;
+    obs.observed_devalloc = data_ptr;
+    if (prog != nullptr) {
+      // Snapshot the ndarray's data gen so the next replay can fast-skip when no kernel / Ndarray API write
+      // has touched the underlying buffer since record time. Mirrors the FieldLoad fast-skip; covers the same
+      // steady-state hot path for ndarray-bounded reverse-mode loops.
+      obs.observed_gen = prog->adstack_cache().ndarray_data_gen(data_ptr);
+    }
+    reads->push_back(std::move(obs));
+  }
+  return v;
+}
+
+int64_t evaluate_external_tensor_shape(const SerializedSizeExprNode &node, LaunchContextBuilder *ctx, ReadSink *reads) {
+  QD_ASSERT_INFO(ctx != nullptr,
+                 "SerializedSizeExpr ExternalTensorShape evaluated with no LaunchContextBuilder; the launcher "
+                 "must pass the current launch's context into the evaluator to resolve ndarray shapes");
+  std::vector<int> arg_indices(node.arg_id_path.begin(), node.arg_id_path.end());
+  arg_indices.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY);
+  arg_indices.push_back(node.arg_shape_axis);
+  // Ndarray shape slots are `int32` in the args struct (same convention `evaluate_external_tensor_read` relies
+  // on for its stride multiplies). Using `int64` here reads 8 bytes from the 4-byte slot and sign-extends the
+  // adjacent field into the shape, so a user-visible downstream effect is that any `SizeExpr` node that feeds a
+  // shape-derived value into a trip count (e.g. `MaxOverRange(0, ExtShape, ...)`) evaluates its range as
+  // garbage - often zero when the adjacent field is zero-initialised - and the containing tree collapses to
+  // zero. The adstack max_size is clamped to 1 on a zero tree result, which under-bounds real push counts and
+  // trips an overflow assertion at the next `qd.sync()`.
+  int64_t v = static_cast<int64_t>(ctx->get_struct_arg_host<int32_t>(arg_indices));
+  if (reads != nullptr) {
+    AdStackCache::SizeExprReadObservation obs;
+    obs.kind = AdStackCache::SizeExprReadObservation::ExternalShapeObs;
+    obs.snode_id = 0;
+    obs.arg_id_path = node.arg_id_path;
+    obs.arg_shape_axis = node.arg_shape_axis;
+    obs.prim_dt = 0;
+    obs.observed_value = v;
+    reads->push_back(std::move(obs));
+  }
+  return v;
+}
+
+int64_t evaluate_node(const SerializedSizeExpr &expr,
+                      int32_t node_idx,
+                      std::unordered_map<int32_t, int64_t> &bound_vars,
+                      Program *prog,
+                      LaunchContextBuilder *ctx,
+                      ReadSink *reads) {
+  QD_ASSERT_INFO(node_idx >= 0 && static_cast<std::size_t>(node_idx) < expr.nodes.size(),
+                 "SerializedSizeExpr node_idx {} out of bounds (size={})", node_idx, expr.nodes.size());
+  const auto &node = expr.nodes[node_idx];
+  switch (static_cast<SizeExpr::Kind>(node.kind)) {
+    case SizeExpr::Kind::Const:
+      return node.const_value;
+    case SizeExpr::Kind::FieldLoad:
+      return evaluate_field_load(node, bound_vars, prog, reads);
+    case SizeExpr::Kind::Add:
+      return evaluate_node(expr, node.operand_a, bound_vars, prog, ctx, reads) +
+             evaluate_node(expr, node.operand_b, bound_vars, prog, ctx, reads);
+    case SizeExpr::Kind::Sub:
+      return std::max<int64_t>(evaluate_node(expr, node.operand_a, bound_vars, prog, ctx, reads) -
+                                   evaluate_node(expr, node.operand_b, bound_vars, prog, ctx, reads),
+                               0);
+    case SizeExpr::Kind::Mul:
+      return evaluate_node(expr, node.operand_a, bound_vars, prog, ctx, reads) *
+             evaluate_node(expr, node.operand_b, bound_vars, prog, ctx, reads);
+    case SizeExpr::Kind::Max:
+      return std::max(evaluate_node(expr, node.operand_a, bound_vars, prog, ctx, reads),
+                      evaluate_node(expr, node.operand_b, bound_vars, prog, ctx, reads));
+    case SizeExpr::Kind::MaxOverRange: {
+      int64_t begin = evaluate_node(expr, node.operand_a, bound_vars, prog, ctx, reads);
+      int64_t end = evaluate_node(expr, node.operand_b, bound_vars, prog, ctx, reads);
+      // Guard against pathological trip counts. The evaluator walks `[begin, end)` linearly and re-evaluates the
+      // body at every i; a range of several million would stall the launch hot path for seconds. Real reverse-mode
+      // trip counts sit well below this cap (a few hundred to a few thousand in practice); anything above is
+      // almost certainly a pre-pass grammar bug the user should file, and a clear QD_ERROR beats a silent hang.
+      constexpr int64_t kMaxOverRangeIterations = int64_t{1} << 24;
+      QD_ERROR_IF(end > begin && end - begin > kMaxOverRangeIterations,
+                  "SerializedSizeExpr MaxOverRange iteration count {} exceeds the {} guard; refusing to enumerate. "
+                  "Shrink the enclosing reverse-mode loop or restructure the `SizeExpr` source kernel.",
+                  end - begin, kMaxOverRangeIterations);
+      int64_t result = 0;
+      // Bind `var_id` in `bound_vars` for the duration of the loop and restore the outer-scope value (or erase, if
+      // there was none) before returning, so nested `MaxOverRange` bindings of the same `var_id` stay correct without
+      // cloning the entire map per iteration.
+      auto prev_it = bound_vars.find(node.var_id);
+      bool had_prev = prev_it != bound_vars.end();
+      int64_t prev_val = had_prev ? prev_it->second : 0;
+      for (int64_t i = begin; i < end; ++i) {
+        bound_vars[node.var_id] = i;
+        int64_t v = evaluate_node(expr, node.body_node_idx, bound_vars, prog, ctx, reads);
+        if (v > result) {
+          result = v;
+        }
+      }
+      if (had_prev) {
+        bound_vars[node.var_id] = prev_val;
+      } else {
+        bound_vars.erase(node.var_id);
+      }
+      return result;
+    }
+    case SizeExpr::Kind::BoundVariable: {
+      auto it = bound_vars.find(node.var_id);
+      QD_ASSERT_INFO(it != bound_vars.end(),
+                     "SerializedSizeExpr BoundVariable var_id={} evaluated outside its MaxOverRange scope",
+                     node.var_id);
+      return it->second;
+    }
+    case SizeExpr::Kind::ExternalTensorShape:
+      return evaluate_external_tensor_shape(node, ctx, reads);
+    case SizeExpr::Kind::ExternalTensorRead:
+      return evaluate_external_tensor_read(node, bound_vars, prog, ctx, reads);
+  }
+  QD_ERROR("unreachable SerializedSizeExpr kind {}", node.kind);
+  return 0;
+}
+
+int64_t read_field_with_launch_cache(int snode_id, const std::vector<int> &indices, Program *prog) {
+  SNode *snode = prog->get_snode_by_id(snode_id);
+  if (snode == nullptr) {
+    return std::numeric_limits<int64_t>::min();
+  }
+  if (t_launch_read_cache != nullptr) {
+    LaunchScopedReadCache::Key key{snode_id, indices};
+    auto it = t_launch_read_cache->map.find(key);
+    if (it != t_launch_read_cache->map.end()) {
+      return it->second;
+    }
+    int64_t v = prog->get_snode_rw_accessors_bank().get(snode).read_int(indices);
+    t_launch_read_cache->map.emplace(std::move(key), v);
+    return v;
+  }
+  return prog->get_snode_rw_accessors_bank().get(snode).read_int(indices);
+}
+
+// Per-thread backing for `SizeExprLaunchScope`. The outer scope on each thread points `t_launch_read_cache` here
+// after clearing the map; nested scopes are no-ops.
+thread_local LaunchScopedReadCache t_launch_read_cache_storage{};
+
+SizeExprLaunchScope::SizeExprLaunchScope() : owns_(t_launch_read_cache == nullptr) {
+  if (owns_) {
+    t_launch_read_cache_storage.map.clear();
+    t_launch_read_cache = &t_launch_read_cache_storage;
+  }
+}
+SizeExprLaunchScope::~SizeExprLaunchScope() {
+  if (owns_) {
+    t_launch_read_cache = nullptr;
+  }
+}
+
+int64_t evaluate_adstack_size_expr_no_cache(const SerializedSizeExpr &expr, Program *prog, LaunchContextBuilder *ctx) {
+  if (expr.nodes.empty()) {
+    return -1;
+  }
+  SizeExprLaunchScope local_scope;
+  std::unordered_map<int32_t, int64_t> empty_bound_vars;
+  std::vector<AdStackCache::SizeExprReadObservation> reads;
+  return evaluate_node(expr, static_cast<int32_t>(expr.nodes.size() - 1), empty_bound_vars, prog, ctx, &reads);
+}
+
+int64_t evaluate_adstack_size_expr(const SerializedSizeExpr &expr, Program *prog, LaunchContextBuilder *ctx) {
+  if (expr.nodes.empty()) {
+    return -1;
+  }
+  // Open a `SizeExprLaunchScope` if no enclosing one is active, so repeated reads within this eval share
+  // the launch read cache. Callers that issue several `evaluate_adstack_size_expr` calls back-to-back
+  // should open their own scope to span all of them.
+  SizeExprLaunchScope local_scope;
+
+  // Cache fast path: replay the recorded reads against the live state and reuse the cached result if
+  // every input still matches. The full walk runs only on cache miss.
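+  // Replay sketch (illustrative; `entry` / `still_matches` are hypothetical names, not this API): a cache
+  // entry pairs the last result with the observations recorded during its walk, and a hit requires every
+  // observation to still hold - either its generation counter is unchanged or a re-read returns the value
+  // recorded at observation time:
+  //   for (const auto &obs : entry.reads)
+  //     if (!still_matches(obs)) return miss;   // fall through to the full walk below
+  //   return entry.result;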
+  if (prog != nullptr) {
+    int64_t cached;
+    if (prog->adstack_cache().try_size_expr_cache_hit(prog, &expr, ctx, cached)) {
+      return cached;
+    }
+  }
+  std::unordered_map<int32_t, int64_t> empty_bound_vars;
+  std::vector<AdStackCache::SizeExprReadObservation> reads;
+  int64_t result =
+      evaluate_node(expr, static_cast<int32_t>(expr.nodes.size() - 1), empty_bound_vars, prog, ctx, &reads);
+  if (prog != nullptr) {
+    prog->adstack_cache().record_size_expr_eval(&expr, result, std::move(reads));
+  }
+  return result;
+}
+
+int64_t evaluate_adstack_size_expr_at_node(const SerializedSizeExpr &expr,
+                                           int32_t node_idx,
+                                           Program *prog,
+                                           LaunchContextBuilder *ctx) {
+  if (node_idx < 0 || static_cast<std::size_t>(node_idx) >= expr.nodes.size()) {
+    return -1;
+  }
+  // The recognizer grammar guarantees the subtree at `node_idx` is closed (no outer-scope `BoundVariable` references),
+  // so an empty bound-vars map is sufficient. Read observations are not recorded - the caller (max-reducer launcher)
+  // does its own observation tracking via `AdStackCache::record_max_reducer_eval` against the spec key, not the
+  // per-`SerializedSizeExpr` key the cache uses for `evaluate_adstack_size_expr`.
+  SizeExprLaunchScope local_scope;
+  std::unordered_map<int32_t, int64_t> empty_bound_vars;
+  return evaluate_node(expr, node_idx, empty_bound_vars, prog, ctx, /*reads=*/nullptr);
+}
+
+} // namespace quadrants::lang
diff --git a/quadrants/program/adstack/eval.h b/quadrants/program/adstack/eval.h
new file mode 100644
index 0000000000..0c2cd522f0
--- /dev/null
+++ b/quadrants/program/adstack/eval.h
@@ -0,0 +1,94 @@
+#pragma once
+
+#include <cstdint>
+#include <unordered_map>
+#include <vector>
+
+#include "quadrants/ir/adstack_size_expr.h"
+#include "quadrants/program/adstack/cache.h"
+
+namespace quadrants::lang {
+
+class LaunchContextBuilder;
+class Program;
+
+// Evaluates a compile-time captured `SerializedSizeExpr` against the current field state of `prog` and the
+// per-launch argument values in `ctx`, returning the concrete adstack capacity for this launch. Scalar i32/i64
+// field loads are serviced by `SNodeRwAccessorsBank` (one reader-kernel dispatch each); ndarray-argument shapes
+// are read from `ctx->get_struct_arg_host`; constants and arithmetic are folded in plain C++; `MaxOverRange`
+// enumerates its range and takes the max of the body expression across the bound variable. Returns -1 when the
+// expression is empty (no symbolic bound captured), signalling to the caller to use the compile-time fallback.
+int64_t evaluate_adstack_size_expr(const SerializedSizeExpr &expr, Program *prog, LaunchContextBuilder *ctx);
+// Variant of `evaluate_adstack_size_expr` that bypasses `size_expr_cache_`. Used by the host-eval branch of the
+// per-task sizer when feeding a stack-local substituted tree (the `size_expr_cache_` is keyed by `SerializedSizeExpr
+// *`, so a transient stack address would alias unrelated cache entries across launches and return wrong cached values).
+// Callers that need cache-warmed evaluation should use `evaluate_adstack_size_expr` with the original tree's stable
+// pointer.
+int64_t evaluate_adstack_size_expr_no_cache(const SerializedSizeExpr &expr, Program *prog, LaunchContextBuilder *ctx);
+
+// Sub-tree variant of `evaluate_adstack_size_expr`: evaluates the subtree rooted at `node_idx` instead of the full
+// tree's root. Used by the max-reducer launcher to host-resolve a captured spec's `begin` / `end` subtrees against the
+// live ctx (the recognizer grammar guarantees both subtrees are closed-form, so the recursive evaluator never re-enters
+// a `MaxOverRange`). Returns -1 when `node_idx` is out of range; -1 from a deeper host-eval failure propagates the same
+// way as in the full-tree variant.
+int64_t evaluate_adstack_size_expr_at_node(const SerializedSizeExpr &expr,
+                                           int32_t node_idx,
+                                           Program *prog,
+                                           LaunchContextBuilder *ctx);
+
+// Diagnose-time variant that evaluates the same `SerializedSizeExpr` against the captured
+// `AdStackCache::DiagnoseLaunchSnapshot` rather than a live `LaunchContextBuilder`. Used by
+// `AdStackCache::diagnose_adstack_overflow` to resolve `ExternalTensorRead` / `ExternalTensorShape` leaves at
+// error time against the live (potentially mutated) ndarray contents, without needing the launch ctx that is
+// gone by sync time on async backends. The cross-backend `Device::map(*allocation, &host_ptr)` path is the
+// design pivot - see `AdStackCache::DiagnoseLaunchSnapshot`'s comment for the rationale (vs. re-dispatching
+// the on-device sizer). Returns -1 if any leaf cannot be resolved (e.g. an arg_id missing from the snapshot,
+// or an allocation whose `Device::map` fails); callers fall back to the static dual-cause body in that case.
+int64_t evaluate_adstack_size_expr_for_diagnose(const SerializedSizeExpr &expr, Program *prog);
+
+// RAII guard opening a thread-local read-cache scope. Every nested `evaluate_adstack_size_expr` running inside the
+// scope shares one cache, so repeated `(snode_id, indices)` reads share a single reader-kernel dispatch. Place around
+// any block that calls `evaluate_adstack_size_expr` more than once back-to-back.
+class SizeExprLaunchScope {
+ public:
+  SizeExprLaunchScope();
+  ~SizeExprLaunchScope();
+  SizeExprLaunchScope(const SizeExprLaunchScope &) = delete;
+  SizeExprLaunchScope &operator=(const SizeExprLaunchScope &) = delete;
+
+ private:
+  bool owns_;
+};
+
+// Internal helper exposed for cross-TU use by `quadrants/program/adstack/device_bytecode.cpp` (host-fold path of
+// `encode_subtree`) and `quadrants/program/adstack/diagnose.cpp` (FieldLoad delegation in the diagnose evaluator). The
+// recursive walker visits every node kind; `bound_vars` carries the live `MaxOverRange` bindings; `reads`, when
+// non-null, accumulates `SizeExprReadObservation` entries for cache invalidation.
+int64_t evaluate_node(const SerializedSizeExpr &expr,
+                      int32_t node_idx,
+                      std::unordered_map<int32_t, int64_t> &bound_vars,
+                      Program *prog,
+                      LaunchContextBuilder *ctx,
+                      std::vector<AdStackCache::SizeExprReadObservation> *reads);
+
+// Internal helper exposed for the diagnose evaluator and the max-reducer body encoder. Resolves a `FieldLoad` leaf via
+// `SNodeRwAccessorsBank::read_int` plus the launch-scoped read cache, optionally appending a `FieldLoadObs` record to
+// `reads`.
+int64_t evaluate_field_load(const SerializedSizeExprNode &node,
+                            std::unordered_map<int32_t, int64_t> &bound_vars,
+                            Program *prog,
+                            std::vector<AdStackCache::SizeExprReadObservation> *reads);
+
+// Internal helper exposed for the max-reducer body encoder's `ExternalTensorShape` host-fold path. Reads the matching
+// shape slot from the kernel arg buffer via `LaunchContextBuilder::get_struct_arg_host` and (when `reads` is non-null)
+// appends an `ExternalShapeObs` observation.
+int64_t evaluate_external_tensor_shape(const SerializedSizeExprNode &node,
+                                       LaunchContextBuilder *ctx,
+                                       std::vector<AdStackCache::SizeExprReadObservation> *reads);
+
+// Internal helper exposed for the cache replay path (`replay_one_observation`). Reads SNode `snode_id` at `indices`
+// through the launch-scoped read cache so multiple size-expr trees evaluated within the same outer launch share a
+// single reader-kernel dispatch per `(snode_id, indices)` pair.
+int64_t read_field_with_launch_cache(int snode_id, const std::vector<int> &indices, Program *prog);
+
+} // namespace quadrants::lang
diff --git a/quadrants/program/adstack/max_reducer.cpp b/quadrants/program/adstack/max_reducer.cpp
new file mode 100644
index 0000000000..8fe93e3686
--- /dev/null
+++ b/quadrants/program/adstack/max_reducer.cpp
@@ -0,0 +1,601 @@
+#include "quadrants/program/adstack/max_reducer.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <functional>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "quadrants/common/logging.h"
+#include "quadrants/ir/adstack_size_expr_device.h"
+#include "quadrants/ir/snode.h"
+#include "quadrants/ir/type.h"
+#include "quadrants/ir/type_factory.h"
+#include "quadrants/program/adstack/device_bytecode.h"
+#include "quadrants/program/adstack/eval.h"
+#include "quadrants/program/launch_context_builder.h"
+#include "quadrants/program/program.h"
+
+namespace quadrants::lang {
+
+namespace {
+
+// True iff the body subtree rooted at `node_idx` references only `Const`, `ExternalTensorRead(arg, [...])` whose every
+// index slot is either a non-negative literal constant or `-(v + 1)` for some `v` in `expected_var_ids`, and `Add` /
+// `Sub` / `Mul` / `Max` of those; `FieldLoad` leaves under the same index rule and `ExternalTensorShape` leaves are
+// also accepted (see the case arms below). Multi-axis ndarray reads are allowed; multiple distinct bound variables
+// from a captured chain of nested `MaxOverRange`s are allowed. The encoder folds the per-axis strides via the live
+// `LaunchContextBuilder` shape reads.
+bool max_reducer_body_is_recognizable(const SerializedSizeExpr &expr,
+                                      int32_t node_idx,
+                                      const std::vector<int32_t> &expected_var_ids) {
+  if (node_idx < 0 || static_cast<std::size_t>(node_idx) >= expr.nodes.size()) {
+    return false;
+  }
+  const auto &n = expr.nodes[node_idx];
+  switch (static_cast<SizeExpr::Kind>(n.kind)) {
+    case SizeExpr::Kind::Const:
+      return true;
+    case SizeExpr::Kind::ExternalTensorRead: {
+      if (n.indices.empty()) {
+        return false;
+      }
+      // Reject `i64` / `u64` body leaves. The cache invalidation scheme stores `INT64_MIN` in the recorded observation
+      // as a "stale" sentinel and revalidates on launch by re-reading the live ndarray and comparing against the saved
+      // value. A 64-bit leaf can legally produce `INT64_MIN` (= 0x80000000_00000000 bit pattern) on host re-read, which
+      // would make a mutated cache entry compare equal to the sentinel and false-hit. Restrict to dtypes whose value
+      // range cannot overlap the sentinel; the device interpreter's `device_load_element` widens any sub-i64 integer
+      // load to i64 via sign- or zero-extension, so this restriction does not lose any reverse-mode trip-count workload
+      // (trip counts are uniformly i32 / u32 in practice). The per-task sizer's existing capped path absorbs anything
+      // outside this dtype set.
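+      // Concrete false-hit this guards against (illustrative values): replay sees a gen mismatch,
+      // dereferences the live cell, and compares the result against the recorded observed_value
+      // sentinel INT64_MIN. An i64 cell mutated to exactly 0x8000'0000'0000'0000 would compare
+      // equal to the sentinel, the entry would be treated as still-valid, and a stale cached
+      // capacity would be reused for a launch whose real bound has changed.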
+      const auto leaf_dt = static_cast<PrimitiveTypeID>(n.const_value);
+      switch (leaf_dt) {
+        case PrimitiveTypeID::i8:
+        case PrimitiveTypeID::i16:
+        case PrimitiveTypeID::i32:
+        case PrimitiveTypeID::u8:
+        case PrimitiveTypeID::u16:
+        case PrimitiveTypeID::u32:
+          break;
+        default:
+          return false;
+      }
+      for (int32_t raw : n.indices) {
+        if (raw >= 0) {
+          continue; // literal constant axis index
+        }
+        const int32_t var_id = -(raw + 1);
+        bool found = false;
+        for (int32_t want : expected_var_ids) {
+          if (want == var_id) {
+            found = true;
+            break;
+          }
+        }
+        if (!found) {
+          return false; // foreign bound variable not bound by the captured chain
+        }
+      }
+      return true;
+    }
+    case SizeExpr::Kind::Add:
+    case SizeExpr::Kind::Sub:
+    case SizeExpr::Kind::Mul:
+    case SizeExpr::Kind::Max:
+      return max_reducer_body_is_recognizable(expr, n.operand_a, expected_var_ids) &&
+             max_reducer_body_is_recognizable(expr, n.operand_b, expected_var_ids);
+    case SizeExpr::Kind::ExternalTensorShape:
+      // `ExternalTensorShape` indices are always literal axis numbers (no bound-var references possible), so it is
+      // unconditionally closed. The encoder host-folds it to `kConst` at encode time via `evaluate_node`.
+      return true;
+    case SizeExpr::Kind::FieldLoad:
+      // `FieldLoad` accepts both literal indices (host-folded by the encoder via `evaluate_field_load` against an empty
+      // bound-var map and emitted as `kConst`) and bound-variable refs from the captured chain. The latter case lowers
+      // to a `kFieldLoad` device node whose base pointer is pre-resolved on host (PSB on SPIR-V, `runtime->roots[id] +
+      // place_byte_offset` on LLVM) and whose per-axis element strides come from `compute_dense_snode_strides`. Foreign
+      // bound-var refs (var_ids outside the captured chain) are rejected since the device-side scope only carries the
+      // chain's axes.
+      for (int32_t raw : n.indices) {
+        if (raw >= 0) {
+          continue;
+        }
+        const int32_t var_id = -(raw + 1);
+        bool found = false;
+        for (int32_t want : expected_var_ids) {
+          if (want == var_id) {
+            found = true;
+            break;
+          }
+        }
+        if (!found) {
+          return false;
+        }
+      }
+      return true;
+    default:
+      return false;
+  }
+}
+
+// True iff the bound subtree rooted at `node_idx` evaluates to a closed-form scalar after substituting any
+// `MaxOverRange` nodes already captured (`captured_mors`) as `Const`s. Allowed: `Const`, `ExternalTensorShape`, `Add` /
+// `Sub` / `Mul` / `Max` of recursively-closed subtrees, and `MaxOverRange` whose node index is in `captured_mors`. On
+// success appends every captured-MOR dependency this subtree references to `deps_out`.
+bool max_reducer_bound_is_closed(const SerializedSizeExpr &expr,
+                                 int32_t node_idx,
+                                 const std::unordered_set<int32_t> &captured_mors,
+                                 std::vector<int32_t> &deps_out) {
+  if (node_idx < 0 || static_cast<std::size_t>(node_idx) >= expr.nodes.size()) {
+    return false;
+  }
+  const auto &n = expr.nodes[node_idx];
+  switch (static_cast<SizeExpr::Kind>(n.kind)) {
+    case SizeExpr::Kind::Const:
+    case SizeExpr::Kind::ExternalTensorShape:
+      return true;
+    case SizeExpr::Kind::Add:
+    case SizeExpr::Kind::Sub:
+    case SizeExpr::Kind::Mul:
+    case SizeExpr::Kind::Max:
+      return max_reducer_bound_is_closed(expr, n.operand_a, captured_mors, deps_out) &&
+             max_reducer_bound_is_closed(expr, n.operand_b, captured_mors, deps_out);
+    case SizeExpr::Kind::MaxOverRange: {
+      if (captured_mors.count(node_idx) == 0) {
+        return false;
+      }
+      deps_out.push_back(node_idx);
+      return true;
+    }
+    default:
+      return false; // FieldLoad, BoundVariable from a non-immediately-enclosing scope, ExternalTensorRead, etc.
+  }
+}
+
+} // namespace
+
+std::vector<StaticAdStackMaxReducerSpec> recognize_adstack_max_reducer_specs(
+    const std::vector<SerializedSizeExpr> &size_exprs) {
+  std::vector<StaticAdStackMaxReducerSpec> specs;
+  for (std::size_t stack_id = 0; stack_id < size_exprs.size(); ++stack_id) {
+    const auto &expr = size_exprs[stack_id];
+    // `SerializedSizeExpr` is built post-order, so deeper `MaxOverRange` nodes always have a smaller `n` than the outer
+    // `MaxOverRange` that depends on them. Iterating ascending `n` visits dependencies before dependants, and
+    // `captured_mors` is always populated in the right order for `max_reducer_bound_is_closed`. The walk also tracks
+    // which `MaxOverRange` nodes have been absorbed as the inner axis of an outer multi-axis spec; those are not
+    // captured separately.
+    std::unordered_set<int32_t> captured_mors;
+    std::unordered_set<int32_t> absorbed_as_inner_axis;
+    for (std::size_t n = 0; n < expr.nodes.size(); ++n) {
+      const auto &node = expr.nodes[n];
+      if (static_cast<SizeExpr::Kind>(node.kind) != SizeExpr::Kind::MaxOverRange) {
+        continue;
+      }
+      if (absorbed_as_inner_axis.count(static_cast<int32_t>(n)) != 0) {
+        continue; // this node is the inner axis of a multi-axis spec captured at an outer node
+      }
+      // Greedy chain capture: starting from `n` (the outermost candidate), descend through nested `MaxOverRange` bodies
+      // as long as each inner `MaxOverRange`'s `[begin, end)` is closed-form (only depends on `Const` /
+      // `ExternalTensorShape` / captured-deeper-MORs). Each layer adds one axis. Stop at the first non-MaxOverRange
+      // body or the first inner `MaxOverRange` whose ranges depend on a chain-bound variable (ragged iteration is not
+      // supported by the rectangular cross-product dispatch).
+      std::vector<int32_t> chain_node_idxs;
+      std::vector<int32_t> chain_var_ids;
+      std::vector<int32_t> chain_begins;
+      std::vector<int32_t> chain_ends;
+      std::vector<int32_t> deps;
+      int32_t cur = static_cast<int32_t>(n);
+      while (true) {
+        const auto &cur_node = expr.nodes[cur];
+        if (static_cast<SizeExpr::Kind>(cur_node.kind) != SizeExpr::Kind::MaxOverRange) {
+          break;
+        }
+        if (!max_reducer_bound_is_closed(expr, cur_node.operand_a, captured_mors, deps)) {
+          break;
+        }
+        if (!max_reducer_bound_is_closed(expr, cur_node.operand_b, captured_mors, deps)) {
+          break;
+        }
+        chain_node_idxs.push_back(cur);
+        chain_var_ids.push_back(cur_node.var_id);
+        chain_begins.push_back(cur_node.operand_a);
+        chain_ends.push_back(cur_node.operand_b);
+        cur = cur_node.body_node_idx;
+      }
+      if (chain_node_idxs.empty()) {
+        continue; // outermost candidate failed the bound-closed check
+      }
+      if (!max_reducer_body_is_recognizable(expr, cur, chain_var_ids)) {
+        continue; // body grammar rejects
+      }
+      StaticAdStackMaxReducerSpec spec;
+      spec.stack_id = static_cast<int32_t>(stack_id);
+      spec.mor_node_idx = chain_node_idxs.front();
+      spec.body_node_idx = cur;
+      spec.axis_var_ids = std::move(chain_var_ids);
+      spec.axis_begin_node_idxs = std::move(chain_begins);
+      spec.axis_end_node_idxs = std::move(chain_ends);
+      spec.dependent_mor_node_idxs = std::move(deps);
+      specs.push_back(std::move(spec));
+      captured_mors.insert(chain_node_idxs.front());
+      // Mark the inner axes as absorbed so the outer loop does not re-capture them as standalone specs.
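+      // Shape of a captured two-axis chain (illustrative): for a tree encoding
+      //   max over i in [0, arr.shape[0]) of max over j in [0, arr.shape[1]) of arr[i, j]
+      // the spec just pushed is rooted at the outer MaxOverRange with axis_var_ids = {i, j};
+      // the inner node's index is recorded below so the ascending-n scan does not re-capture
+      // it as a standalone single-axis spec.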
+      for (std::size_t i = 1; i < chain_node_idxs.size(); ++i) {
+        absorbed_as_inner_axis.insert(chain_node_idxs[i]);
+      }
+    }
+  }
+  return specs;
+}
+
+EncodedMaxReducerBody encode_max_reducer_body_bytecode(
+    const SerializedSizeExpr &expr,
+    int32_t body_node_idx,
+    const std::vector<int32_t> &bound_var_ids,
+    const std::function<int32_t(const std::vector<int> &arg_id_path)> &arg_buffer_offset_resolver,
+    LaunchContextBuilder *ctx,
+    Program *prog,
+    const FieldLoadDeviceEmitter *fl_emitter) {
+  EncodedMaxReducerBody out;
+  if (body_node_idx < 0 || static_cast<std::size_t>(body_node_idx) >= expr.nodes.size()) {
+    return out;
+  }
+  // Post-order DFS to collect reachable node indices from `body_node_idx`. The recognizer grammar guarantees no
+  // `kMaxOverRange` in the body subtree, so we only need to follow `operand_a` / `operand_b` (binary ops);
+  // `kExternalTensorRead` / `kExternalTensorShape` / `kFieldLoad` are leaves (their operand fields are unused by the
+  // device interpreter). The resulting `post_order` vector is sorted such that any node's operands precede the node
+  // itself.
+  std::vector<int32_t> post_order;
+  std::unordered_map<int32_t, int32_t> old_to_new; // old idx -> dense [0, body_node_count)
+  std::function<void(int32_t)> visit = [&](int32_t idx) {
+    if (idx < 0 || old_to_new.count(idx) != 0) {
+      return;
+    }
+    const auto &n = expr.nodes[idx];
+    auto kind = static_cast<SizeExpr::Kind>(n.kind);
+    if (kind == SizeExpr::Kind::Add || kind == SizeExpr::Kind::Sub || kind == SizeExpr::Kind::Mul ||
+        kind == SizeExpr::Kind::Max) {
+      visit(n.operand_a);
+      visit(n.operand_b);
+    }
+    // `kConst`, `kBoundVariable`, `kExternalTensorRead`, `kExternalTensorShape`, `kFieldLoad` are leaves
+    // (`kExternalTensorShape` and closed `kFieldLoad` are host-folded to `kConst` below; their operand fields hold
+    // metadata, not subtree pointers).
+    int32_t new_idx = static_cast<int32_t>(post_order.size());
+    old_to_new[idx] = new_idx;
+    post_order.push_back(idx);
+  };
+  visit(body_node_idx);
+
+  out.body_node_count = static_cast<uint32_t>(post_order.size());
+
+  // Build the flat indices table for any `kExternalTensorRead` leaves. Each leaf carries `indices_count` axes; each
+  // axis contributes one `(idx_raw, elem_stride)` pair. `idx_raw` mirrors the host SerializedSizeExprNode encoding
+  // (`-(var_id + 1)` for bound-var refs, non-negative for constants); the encoder remaps every captured chain bound-var
+  // ref to a dense device-scope slot in `[0, bound_var_ids.size())` (axis 0 = outermost MaxOverRange = device-scope
+  // slot 0, axis 1 = next-inner = slot 1, ...). The dispatch site pre-populates each scope slot per iteration of the
+  // cross-product. `elem_stride` is folded against the live ndarray shape, matching the per-task sizer encoder's
+  // stride-emission pattern.
+  std::vector<int32_t> indices_table;
+  // Map each host-side bound-var id in the captured chain to its dense device-scope slot.
+  auto remap_chain_var = [&](int32_t host_var_id) -> int32_t {
+    for (std::size_t k = 0; k < bound_var_ids.size(); ++k) {
+      if (bound_var_ids[k] == host_var_id) {
+        return static_cast<int32_t>(k);
+      }
+    }
+    return -1;
+  };
+  // Build `AdStackSizeExprDeviceNode`s in post-order. We only emit fields the device interpreter reads for the
+  // recognized grammar; unused fields stay at their default values.
+  std::vector<AdStackSizeExprDeviceNode> device_nodes(post_order.size());
+  for (std::size_t i = 0; i < post_order.size(); ++i) {
+    const auto &src = expr.nodes[post_order[i]];
+    auto &dst = device_nodes[i];
+    dst.var_id = -1;
+    auto kind = static_cast<SizeExpr::Kind>(src.kind);
+    // Map the host `SizeExpr::Kind` enum into the device-side `AdStackSizeExprDeviceKind` enum: the two enums use
+    // different integer values (e.g. host `ExternalTensorRead = 9` vs. device `kExternalTensorRead = 7`), so a raw
+    // assignment lands every body node in the device interpreter's switch default and returns 0 on every walk. Mirror
+    // the explicit translation the per-task adstack-sizer encoder does (search `AdStackSizeExprDeviceKind::` in this TU
+    // for the canonical pattern); the max-reducer body grammar narrows to the subset listed below.
+    switch (kind) {
+      case SizeExpr::Kind::Const:
+        dst.kind = static_cast<int32_t>(AdStackSizeExprDeviceKind::kConst);
+        dst.const_value = src.const_value;
+        break;
+      case SizeExpr::Kind::BoundVariable: {
+        // Device-side scope holds the captured chain bound variables at slots `[0, bound_var_ids.size())`,
+        // outermost-first. The runtime function / SPIR-V max-reducer shader pre-populates each slot with the current
+        // cross-product index before walking the body bytecode.
+        dst.kind = static_cast<int32_t>(AdStackSizeExprDeviceKind::kBoundVariable);
+        const int32_t slot = remap_chain_var(src.var_id);
+        if (slot < 0) {
+          // Foreign bound var leaked past the recognizer; signal failure rather than silently aliasing slot 0.
+          return EncodedMaxReducerBody{};
+        }
+        dst.var_id = slot;
+        break;
+      }
+      case SizeExpr::Kind::ExternalTensorRead: {
+        dst.kind = static_cast<int32_t>(AdStackSizeExprDeviceKind::kExternalTensorRead);
+        dst.prim_dt = static_cast<int32_t>(src.const_value);
+        // Resolve `arg_buffer_offset` from `arg_id_path` via the caller's resolver.
+        std::vector<int> path(src.arg_id_path.begin(), src.arg_id_path.end());
+        const int32_t arg_buf_off = arg_buffer_offset_resolver(path);
+        if (arg_buf_off < 0) {
+          return EncodedMaxReducerBody{}; // resolver failed; signal empty result
+        }
+        dst.arg_buffer_offset = arg_buf_off;
+        // Indices table: emit `(idx_raw, elem_stride)` per axis. Bound-variable refs (`-(this_var_id + 1)` in the host
+        // tree) become `-(slot + 1)` so the device-side scope resolves them. Non-negative entries pass
+        // through as compile-time literal indices. Per-axis element strides are folded against the live ndarray shape
+        // read from `ctx->args_type` (the same `SHAPE_POS_IN_NDARRAY` path the per-task sizer encoder and the host
+        // `evaluate_external_tensor_read` use). `ctx == nullptr` falls back to flat strides; in that mode multi-axis
+        // reads are encoded as if they were single-axis, which is correct only for `indices.size() == 1` callers.
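+        // Encoding example (illustrative): with bound_var_ids = {7, 3} (outermost first), a host
+        // index list {2, -(3 + 1)} emits (2, stride0) for the literal axis and (-(1 + 1), stride1)
+        // for the bound-var axis - host var_id 3 maps to device-scope slot 1, re-encoded as
+        // -(slot + 1).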
+        const int32_t indices_off = static_cast<int32_t>(indices_table.size());
+        const std::size_t n_axes = src.indices.size();
+        std::vector<int32_t> elem_strides(n_axes, 1);
+        if (n_axes > 1 && ctx != nullptr) {
+          for (std::size_t k = n_axes; k-- > 0;) {
+            if (k + 1 < n_axes) {
+              std::vector<int> sh_idx(src.arg_id_path.begin(), src.arg_id_path.end());
+              sh_idx.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY);
+              sh_idx.push_back(static_cast<int>(k + 1));
+              const int32_t sh = ctx->get_struct_arg_host<int32_t>(sh_idx);
+              elem_strides[k] = elem_strides[k + 1] * sh;
+            }
+          }
+        }
+        for (std::size_t a = 0; a < n_axes; ++a) {
+          int64_t raw = src.indices[a];
+          int32_t emit_raw;
+          if (raw >= 0) {
+            emit_raw = static_cast<int32_t>(raw);
+          } else {
+            const int32_t host_var_id = static_cast<int32_t>(-(raw + 1));
+            const int32_t slot = remap_chain_var(host_var_id);
+            if (slot < 0) {
+              // Foreign bound var leaked past the recognizer; analyzer invariant violation.
+              return EncodedMaxReducerBody{};
+            }
+            // Encode dense device-scope slot as `-(slot + 1)`. The dispatch site / runtime walks the body with
+            // `scope.values[slot]` pre-populated for the current cross-product iteration.
+            emit_raw = -(slot + 1);
+          }
+          indices_table.push_back(emit_raw);
+          indices_table.push_back(elem_strides[a]);
+        }
+        dst.indices_offset = indices_off;
+        dst.indices_count = static_cast<int32_t>(n_axes);
+
+        // Record a body observation entry so the caller can populate the cache's read list. The caller fills in
+        // `observed_value` and `observed_gen` post-eval (we do not have the live ctx here).
+        AdStackCache::SizeExprReadObservation obs{};
+        obs.kind = AdStackCache::SizeExprReadObservation::ExternalReadObs;
+        obs.snode_id = -1;
+        obs.arg_id_path = std::vector<int>(src.arg_id_path.begin(), src.arg_id_path.end());
+        obs.prim_dt = static_cast<int32_t>(src.const_value);
+        out.body_reads.push_back(std::move(obs));
+        break;
+      }
+      case SizeExpr::Kind::ExternalTensorShape: {
+        // Closed leaf - resolve the shape value host-side at encode time and emit it as a `kConst` so the device
+        // interpreter never walks `args_type`. The dispatch site re-runs the encoder per launch, so a subsequent launch
+        // binding a different ndarray re-folds against the live shape; the cache invalidation rides on the
+        // `ExternalShapeObs` recorded below.
+        std::vector<AdStackCache::SizeExprReadObservation> read_sink;
+        const int64_t v = evaluate_external_tensor_shape(src, ctx, &read_sink);
+        dst.kind = static_cast<int32_t>(AdStackSizeExprDeviceKind::kConst);
+        dst.const_value = v;
+        for (auto &obs : read_sink) {
+          out.body_reads.push_back(std::move(obs));
+        }
+        break;
+      }
+      case SizeExpr::Kind::FieldLoad: {
+        if (prog == nullptr) {
+          return EncodedMaxReducerBody{}; // FieldLoad needs a live Program for snode resolution.
+        }
+        // Closed FieldLoad (every index slot is a literal constant) host-folds via `evaluate_field_load` to a `kConst`
+        // leaf at encode time. The recorded `FieldLoadObs` carries the snode write-gen so a subsequent launch that has
+        // not bumped the gen replays the cached value, mirroring the `kExternalTensorShape` host-fold path.
+        bool has_bound_var_index = false;
+        for (int32_t raw : src.indices) {
+          if (raw < 0) {
+            has_bound_var_index = true;
+            break;
+          }
+        }
+        if (!has_bound_var_index) {
+          std::unordered_map<int32_t, int64_t> empty_bound;
+          std::vector<AdStackCache::SizeExprReadObservation> read_sink;
+          const int64_t v = evaluate_field_load(src, empty_bound, prog, &read_sink);
+          dst.kind = static_cast<int32_t>(AdStackSizeExprDeviceKind::kConst);
+          dst.const_value = v;
+          for (auto &obs : read_sink) {
+            out.body_reads.push_back(std::move(obs));
+          }
+          break;
+        }
+        // Bound-var-indexed FieldLoad: emit a `kFieldLoad` device node that the body interpreter resolves per
+        // cross-product iteration. Backend-specific base resolution: SPIR-V passes a non-empty `fl_emitter` whose
+        // `fetch` returns `root_psb + place_byte_offset` (pre-baked PSB address); LLVM passes a null emitter and we
+        // resolve `(snode_root_id, place_byte_offset)` directly via `prog`, which the LLVM device interpreter then
+        // resolves at runtime via `runtime->roots[snode_root_id] + place_byte_offset`. Per-axis element strides come
+        // from `compute_dense_snode_strides` (units = elements of the leaf primitive type, not bytes), shared with the
+        // per-task sizer's `kFieldLoad` arm.
+        SNode *snode = prog->get_snode_by_id(src.snode_id);
+        if (snode == nullptr) {
+          return EncodedMaxReducerBody{};
+        }
+        auto *prim_ty = snode->dt->cast<PrimitiveType>();
+        if (prim_ty == nullptr) {
+          return EncodedMaxReducerBody{};
+        }
+        // Same dtype restriction as `kExternalTensorRead`: the cache-revalidation sentinel `INT64_MIN` must be
+        // unreachable from a freshly-loaded leaf value, so reject `i64 / u64` leaves where a mutated cell could legally
+        // hold the sentinel and false-hit on revalidation.
+        const auto leaf_dt = prim_ty->type;
+        switch (leaf_dt) {
+          case PrimitiveTypeID::i8:
+          case PrimitiveTypeID::i16:
+          case PrimitiveTypeID::i32:
+          case PrimitiveTypeID::u8:
+          case PrimitiveTypeID::u16:
+          case PrimitiveTypeID::u32:
+            break;
+          default:
+            return EncodedMaxReducerBody{};
+        }
+        std::vector<int32_t> elem_strides;
+        if (!compute_dense_snode_strides(snode, &elem_strides)) {
+          return EncodedMaxReducerBody{};
+        }
+        if (elem_strides.size() != src.indices.size()) {
+          return EncodedMaxReducerBody{};
+        }
+        int32_t snode_root_id = -1;
+        int64_t base_or_place_off = 0;
+        if (fl_emitter != nullptr && !fl_emitter->empty()) {
+          uint64_t base_psb = 0;
+          std::vector<int32_t> emitter_strides;
+          if (!fl_emitter->fetch(snode, &base_psb, &emitter_strides)) {
+            return EncodedMaxReducerBody{};
+          }
+          base_or_place_off = static_cast<int64_t>(base_psb);
+        } else {
+          // LLVM path: store `snode_root_id` in `arg_buffer_offset` (unused by FieldLoad on SPIR-V) and
+          // `place_byte_offset` in `const_value`. The LLVM device interpreter reads `runtime->roots[snode_root_id] +
+          // place_byte_offset` and adds the per-axis-stride-weighted element offset.
+          snode_root_id = snode->get_snode_tree_id();
+          base_or_place_off = static_cast<int64_t>(prog->get_field_in_tree_offset(snode_root_id, snode));
+        }
+        dst.kind = static_cast<int32_t>(AdStackSizeExprDeviceKind::kFieldLoad);
+        dst.prim_dt = static_cast<int32_t>(leaf_dt);
+        dst.arg_buffer_offset = snode_root_id;
+        dst.const_value = base_or_place_off;
+        const int32_t indices_off = static_cast<int32_t>(indices_table.size());
+        for (std::size_t a = 0; a < src.indices.size(); ++a) {
+          int32_t emit_raw;
+          int64_t raw = src.indices[a];
+          if (raw >= 0) {
+            emit_raw = static_cast<int32_t>(raw);
+          } else {
+            const int32_t host_var_id = static_cast<int32_t>(-(raw + 1));
+            const int32_t slot = remap_chain_var(host_var_id);
+            if (slot < 0) {
+              return EncodedMaxReducerBody{};
+            }
+            emit_raw = -(slot + 1);
+          }
+          indices_table.push_back(emit_raw);
+          indices_table.push_back(elem_strides[a]);
+        }
+        dst.indices_offset = indices_off;
+        dst.indices_count = static_cast<int32_t>(src.indices.size());
+        // Push a `FieldLoadObs` skeleton: snode_id is the staleness key; `indices = {}` signals to
+        // `replay_one_observation`'s FieldLoadObs arm that the gen counter is the sole staleness signal (the body is
+        // evaluated at every cross-product point so there is no canonical scalar to re-read on a gen mismatch).
+        // `populate_max_reducer_body_observations` fills in `observed_value` (sentinel) and `observed_gen` at dispatch
+        // time once a live `AdStackCache` is in scope.
+        AdStackCache::SizeExprReadObservation obs{};
+        obs.kind = AdStackCache::SizeExprReadObservation::FieldLoadObs;
+        obs.snode_id = src.snode_id;
+        obs.prim_dt = static_cast<int32_t>(leaf_dt);
+        out.body_reads.push_back(std::move(obs));
+        break;
+      }
+      case SizeExpr::Kind::Add:
+      case SizeExpr::Kind::Sub:
+      case SizeExpr::Kind::Mul:
+      case SizeExpr::Kind::Max: {
+        if (kind == SizeExpr::Kind::Add) {
+          dst.kind = static_cast<int32_t>(AdStackSizeExprDeviceKind::kAdd);
+        } else if (kind == SizeExpr::Kind::Sub) {
+          dst.kind = static_cast<int32_t>(AdStackSizeExprDeviceKind::kSub);
+        } else if (kind == SizeExpr::Kind::Mul) {
+          dst.kind = static_cast<int32_t>(AdStackSizeExprDeviceKind::kMul);
+        } else {
+          dst.kind = static_cast<int32_t>(AdStackSizeExprDeviceKind::kMax);
+        }
+        auto map_op = [&](int32_t old) -> int32_t {
+          auto it = old_to_new.find(old);
+          return it == old_to_new.end() ? -1 : it->second;
+        };
+        dst.operand_a = map_op(src.operand_a);
+        dst.operand_b = map_op(src.operand_b);
+        break;
+      }
+      default:
+        // Out-of-grammar kind reached the encoder; the caller should have filtered via
+        // `recognize_adstack_max_reducer_specs`. Return empty to signal failure.
+        return EncodedMaxReducerBody{};
+    }
+  }
+
+  out.indices_count = static_cast<uint32_t>(indices_table.size());
+  // Concatenate `[device_nodes][indices_table]` into the output bytes buffer.
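+  // Layout after the two copies below (illustrative):
+  //   bytes = [AdStackSizeExprDeviceNode x body_node_count][int32_t x indices_count]
+  // so a node's (idx_raw, elem_stride) pairs are assumed to be resolved by the device
+  // interpreter at
+  //   base + body_node_count * sizeof(AdStackSizeExprDeviceNode) + indices_offset * sizeof(int32_t).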
+  const std::size_t nodes_bytes = device_nodes.size() * sizeof(AdStackSizeExprDeviceNode);
+  const std::size_t indices_bytes = indices_table.size() * sizeof(int32_t);
+  out.bytes.resize(nodes_bytes + indices_bytes);
+  if (nodes_bytes > 0) {
+    std::memcpy(out.bytes.data(), device_nodes.data(), nodes_bytes);
+  }
+  if (indices_bytes > 0) {
+    std::memcpy(out.bytes.data() + nodes_bytes, indices_table.data(), indices_bytes);
+  }
+  return out;
+}
+
+SerializedSizeExpr substitute_precomputed_max_over_range(const SerializedSizeExpr &expr,
+                                                         uint32_t registry_id,
+                                                         int32_t stack_id,
+                                                         const MaxReducerResultMap &results) {
+  if (results.empty()) {
+    return expr;
+  }
+  auto pack_key = [&](std::size_t n) {
+    return (static_cast<uint64_t>(registry_id) & 0xFFFFFFFFull) |
+           ((static_cast<uint64_t>(stack_id) & 0xFFFFull) << 32) | ((static_cast<uint64_t>(n) & 0xFFFFull) << 48);
+  };
+  // Cheap precheck: any `MaxOverRange` node in this expr with a key in `results`? If not, return verbatim.
+  bool any_match = false;
+  for (std::size_t n = 0; n < expr.nodes.size(); ++n) {
+    if (static_cast<SizeExpr::Kind>(expr.nodes[n].kind) != SizeExpr::Kind::MaxOverRange) {
+      continue;
+    }
+    if (results.count(pack_key(n)) != 0) {
+      any_match = true;
+      break;
+    }
+  }
+  if (!any_match) {
+    return expr;
+  }
+  // Build a copy with substitution applied to matching MaxOverRange nodes. Node count is unchanged so operand
+  // indices in non-substituted nodes stay valid; substituted nodes become `kConst` leaves whose `const_value` is
+  // the dispatched max-reducer result.
+  SerializedSizeExpr out = expr;
+  for (std::size_t n = 0; n < out.nodes.size(); ++n) {
+    auto &node = out.nodes[n];
+    if (static_cast<SizeExpr::Kind>(node.kind) != SizeExpr::Kind::MaxOverRange) {
+      continue;
+    }
+    auto it = results.find(pack_key(n));
+    if (it == results.end()) {
+      continue;
+    }
+    node.kind = static_cast<int32_t>(SizeExpr::Kind::Const);
+    node.const_value = it->second;
+    // Defensive cleanup: the host evaluator's `kConst` arm reads only `const_value`. Reset operand / body /
+    // var_id slots to -1 so any future reader that does not branch on `kind` produces a deterministic failure
+    // rather than reading stale indices.
+    node.operand_a = -1;
+    node.operand_b = -1;
+    node.body_node_idx = -1;
+    node.var_id = -1;
+  }
+  return out;
+}
+
+} // namespace quadrants::lang
diff --git a/quadrants/program/adstack/max_reducer.h b/quadrants/program/adstack/max_reducer.h
new file mode 100644
index 0000000000..3f98e519ca
--- /dev/null
+++ b/quadrants/program/adstack/max_reducer.h
@@ -0,0 +1,102 @@
+#pragma once
+
+#include <cstdint>
+#include <functional>
+#include <unordered_map>
+#include <vector>
+
+#include "quadrants/ir/adstack_size_expr.h"
+#include "quadrants/program/adstack/cache.h"
+#include "quadrants/transforms/static_adstack_analysis.h"
+
+namespace quadrants::lang {
+
+class LaunchContextBuilder;
+class Program;
+
+// Type alias for the max-reducer result map. Keyed by `(registry_id, stack_id, mor_node_idx)` packed via the same
+// `pack_max_reducer_key` encoding `AdStackCache::try_max_reducer_cache_hit` uses, so a single map shared between the
+// dispatch path and the substitution helper avoids re-packing at every lookup.
+using MaxReducerResultMap = std::unordered_map<uint64_t, int64_t>;
+
+// Extract a captured `MaxOverRange`'s body subtree from `expr` and emit it as a flat `[AdStackSizeExprDeviceNode x
+// body_node_count][int32_t x indices_count]` bytecode blob, returned as a single `[uint8_t]` byte buffer ready to
+// upload to a device storage buffer. Reachable nodes are walked in post-order from `body_node_idx` and renumbered to
+// dense `[0, body_node_count)` indices; the `(idx_raw, elem_stride)` pairs `kExternalTensorRead` / `kFieldLoad` read
+// are appended into the same flat buffer at offset `body_node_count * sizeof(AdStackSizeExprDeviceNode)`. Returns the
+// raw bytes plus `body_node_count` and `indices_count` so the caller can populate the matching
+// `AdStackMaxReducerParams` / `LlvmAdStackMaxReducerDeviceParams` fields. The recognizer grammar guarantees the body
+// subtree contains no `kMaxOverRange` (bound-variable-indexed `kFieldLoad` leaves are allowed and lower to device
+// nodes), so the body interpreter only needs the small grammar set the SPIR-V max-reducer shader and the LLVM
+// `runtime_eval_adstack_max_reduce` runtime function both implement.
+//
+// `arg_buffer_offset_resolver` resolves `(arg_id_path) -> byte_offset_in_arg_buffer` for `kExternalTensorRead` leaves.
+// On the gfx caller path this is a closure over `LaunchContextBuilder::args_type::get_element_offset` (same path the
+// SizeExpr device-bytecode encoder uses). On the LLVM caller path the resolver mirrors the per-task adstack sizer's
+// arg-buffer-offset precomputation. Returns `-1` on resolution failure (caller should hard-error or skip the spec).
+struct EncodedMaxReducerBody {
+  std::vector<uint8_t> bytes;
+  uint32_t body_node_count{0};
+  uint32_t indices_count{0};
+  // Reads observed during encoding: one entry per body leaf (`kExternalTensorRead`) and per begin/end leaf the caller
+  // resolved separately. Used by `AdStackCache::record_max_reducer_eval` so the next launch can short-circuit on a
+  // generation match. Caller fills in the begin/end observations and appends body observations from this list.
+  std::vector<AdStackCache::SizeExprReadObservation> body_reads;
+};
+// Forward decl: defined in `quadrants/program/adstack/device_bytecode.h`. Including the full header here would create a
+// cycle (`device_bytecode.h` already includes this header for `MaxReducerResultMap`). The encoder only references the
+// struct's `fetch` field via the pointer parameter so a forward declaration is enough at this site.
+struct FieldLoadDeviceEmitter;
+
+EncodedMaxReducerBody encode_max_reducer_body_bytecode(
+    const SerializedSizeExpr &expr,
+    int32_t body_node_idx,
+    const std::vector<int32_t> &bound_var_ids,
+    const std::function<int32_t(const std::vector<int> &arg_id_path)> &arg_buffer_offset_resolver,
+    LaunchContextBuilder *ctx,
+    Program *prog,
+    const FieldLoadDeviceEmitter *fl_emitter = nullptr);
+
+// Snapshot the live ndarray data pointer + generation counter into each `ExternalReadObs` record. The encoder emits the
+// observation skeleton (kind / arg_id_path / prim_dt) but cannot fill in the runtime-resolved `data_ptr` /
+// `observed_gen` / `observed_value` because it has no `LaunchContextBuilder`. This helper closes that gap right before
+// the max-reducer dispatch site calls `AdStackCache::record_max_reducer_eval`, so the next launch's
+// `try_max_reducer_cache_hit` replay can fast-skip on a matching `ndarray_data_gen`. `observed_value` is recorded as
+// `INT64_MIN` so the replay's gen-mismatch dereference path returns a value strictly greater than the recorded sentinel
+// and forces the cache to invalidate; the cached max itself is stored in `MaxReducerCacheEntry::result`, not in any
+// per-leaf observation.
+void populate_max_reducer_body_observations(std::vector<AdStackCache::SizeExprReadObservation> &reads,
+                                            LaunchContextBuilder *ctx,
+                                            AdStackCache *cache);
+
+// Walk every per-stack `SerializedSizeExpr` in `size_exprs` post-order and return the list of `MaxOverRange` nodes the
+// runtime can reduce in parallel via a dedicated max-reducer dispatch. Each returned spec references its alloca by
+// `stack_id` (index into `size_exprs`) and its `MaxOverRange` by `mor_node_idx` (index into
+// `size_exprs[stack_id].nodes`). Specs are returned in dependency order: deeper `MaxOverRange` nodes first so the
+// runtime can substitute their results before evaluating outer nodes that depend on them. Grammar:
+// * `body` subtree references only `Const`, `ExternalTensorRead` / `FieldLoad` leaves whose every index axis is a
+//   literal constant or a chain-bound variable (multi-axis reads allowed), `ExternalTensorShape` (host-folded to a
+//   `Const` at encode time), and `Add` / `Sub` / `Mul` / `Max` of those. Sub-64-bit integer dtype on every data leaf
+//   (`i64` / `u64` are rejected to keep the cache's `INT64_MIN` staleness sentinel unreachable).
+// * `begin` and `end` subtrees reference only `Const`, `ExternalTensorShape`, `Add` / `Sub` / `Mul` / `Max`, or another
+// `MaxOverRange` already captured deeper in the same tree (becomes a `Const` after substitution). Anything outside the
+// grammar is skipped silently; that `MaxOverRange` continues to fall through to the existing capped path (host
+// hard-error when `QD_DEBUG_ADSTACK=1`, silent truncation otherwise).
+std::vector<StaticAdStackMaxReducerSpec> recognize_adstack_max_reducer_specs(
+    const std::vector<SerializedSizeExpr> &size_exprs);
+
+// Walk `expr.nodes`, replace every captured `MaxOverRange` node whose `(registry_id, stack_id, mor_node_idx)` is in
+// `results` with a `Const` carrying the dispatched value. Other nodes (and their `operand_a` / `operand_b` /
+// `body_node_idx` references) are copied through verbatim. The returned `SerializedSizeExpr` has `nodes.size() ==
+// expr.nodes.size()` (in-place substitution); operand indices in non-substituted nodes remain valid because the count
+// is unchanged.
+//
+// Empty-input fast path: when no captured spec matches this `(registry_id, stack_id)` (computed by checking `results`
+// against every `MaxOverRange` node in `expr`), return `expr` unchanged (the caller's reference into the per-stack tree
+// stays valid). `SerializedSizeExpr` is returned by value so the caller can transparently swap the reference
+// depending on whether substitution fired.
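+// Example (illustrative): if `results` maps pack(registry_id, stack_id, n) -> 42, the
+// MaxOverRange node at index n becomes Const(42) in the returned copy, and the host evaluator
+// then folds the enclosing tree without enumerating that range. A sketch of the intended use,
+// per the `evaluate_adstack_size_expr_no_cache` doc above (the substituted tree is stack-local):
+//   SerializedSizeExpr folded = substitute_precomputed_max_over_range(expr, rid, sid, results);
+//   int64_t capacity = evaluate_adstack_size_expr_no_cache(folded, prog, ctx);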
+SerializedSizeExpr substitute_precomputed_max_over_range(const SerializedSizeExpr &expr,
+                                                         uint32_t registry_id,
+                                                         int32_t stack_id,
+                                                         const MaxReducerResultMap &results);
+
+} // namespace quadrants::lang
diff --git a/quadrants/program/adstack/write_gen.cpp b/quadrants/program/adstack/write_gen.cpp
new file mode 100644
index 0000000000..057951cd33
--- /dev/null
+++ b/quadrants/program/adstack/write_gen.cpp
@@ -0,0 +1,161 @@
+#include "quadrants/program/adstack/write_gen.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+#include "quadrants/codegen/llvm/llvm_compiled_data.h"
+#include "quadrants/codegen/spirv/kernel_utils.h"
+#include "quadrants/common/logging.h"
+#include "quadrants/ir/type_factory.h"
+#include "quadrants/program/adstack/cache.h"
+#include "quadrants/program/adstack/eval.h"
+#include "quadrants/program/launch_context_builder.h"
+#include "quadrants/program/program.h"
+#include "quadrants/transforms/static_adstack_analysis.h"
+
+namespace quadrants::lang {
+
+void clip_effective_rows_by_loop_trip_count(std::size_t &effective_rows,
+                                            const StaticAdStackBoundExpr &bound_expr,
+                                            std::size_t dispatched_threads_ceiling,
+                                            Program *prog,
+                                            LaunchContextBuilder *ctx) {
+  if (bound_expr.loop_iter_static > 0) {
+    // Compile-time trip count: integer compare, no per-launch eval cost. Constant `SizeExpr` shapes are
+    // already collapsed into this field by the analyzer so they short-circuit the runtime eval below.
+    const std::size_t loop_iter_static = static_cast<std::size_t>(bound_expr.loop_iter_static);
+    if (loop_iter_static <= dispatched_threads_ceiling) {
+      effective_rows = std::min(effective_rows, loop_iter_static);
+    }
+    return;
+  }
+  if (bound_expr.loop_iter_size_expr.nodes.empty() || prog == nullptr || ctx == nullptr) {
+    // Runtime tree empty or no resolution context: the analyzer left this field unset for shapes the
+    // compile-time path could not cover (or the caller did not supply a `Program` / `LaunchContextBuilder`),
+    // so leave `effective_rows` alone and let the caller fall back to the unclipped reducer count.
+    return;
+  }
+  // Runtime-bounded clip: evaluate the captured trip-count `SizeExpr` only when the static field is unset
+  // (the analyzer leaves `loop_iter_static == 0` for shapes the compile-time path cannot cover, e.g.
+  // `for j in range(field[i])` / `for k in range(arr.shape[axis])`). Cost = one tree walk per launch,
+  // dominated by host scalar reads through `SNodeRwAccessorsBank` on `FieldLoad` / `ExternalTensorRead`
+  // nodes (CPU: a memory load; CUDA / AMDGPU: a 4-8 byte DtoH). The evaluator returns -1 when the tree
+  // references state that is not host-resolvable from `ctx`; in that case we leave `effective_rows`
+  // unclipped from this source.
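+  // Numeric example (illustrative): the reducer counted 1'000'000 candidate rows from an
+  // oversized gating ndarray, the captured trip-count tree evaluates to 4096 for this launch,
+  // and the dispatch ceiling is 65536 -> effective_rows clips from 1'000'000 to 4096, since
+  // each loop iteration claims at most one row at the LCA-block.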
+  const int64_t evaluated = evaluate_adstack_size_expr(bound_expr.loop_iter_size_expr, prog, ctx);
+  if (evaluated > 0 && static_cast<std::size_t>(evaluated) <= dispatched_threads_ceiling) {
+    effective_rows = std::min(effective_rows, static_cast<std::size_t>(evaluated));
+  }
+}
+
+void bump_writes_for_kernel_llvm(Program *prog,
+                                 LaunchContextBuilder *ctx,
+                                 const std::vector<OffloadedTask> &offloaded_tasks) {
+  if (prog == nullptr) {
+    return;
+  }
+  auto bump_data_ptr = [&](int arg_id) {
+    ArgArrayPtrKey data_key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
+    auto it = ctx->array_ptrs.find(data_key);
+    if (it != ctx->array_ptrs.end() && it->second != nullptr) {
+      prog->adstack_cache().bump_ndarray_data_gen(it->second);
+    }
+  };
+  for (const auto &task : offloaded_tasks) {
+    for (int snode_id : task.snode_writes) {
+      prog->adstack_cache().bump_snode_write_gen(snode_id);
+    }
+    for (int arg_id : task.arr_writes) {
+      bump_data_ptr(arg_id);
+    }
+    // Read-only `DevAllocType::kNone` args also need a bump: the user's host array is either H2D-blitted to a
+    // temporary device buffer (CUDA / AMDGPU) or read directly (CPU), and in both cases the data pointer used as
+    // the cache key is stable across launches, so a content mutation the user performed outside Quadrants's
+    // tracking is invisible to the metadata cache without an explicit bump. Mirrors the SPIR-V `kone_h2d_blit`
+    // rule in `bump_writes_for_kernel_spirv`.
+    for (int arg_id : task.arr_reads) {
+      auto type_it = ctx->device_allocation_type.find(arg_id);
+      if (type_it == ctx->device_allocation_type.end() ||
+          type_it->second != LaunchContextBuilder::DevAllocType::kNone) {
+        continue;
+      }
+      bump_data_ptr(arg_id);
+    }
+  }
+}
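The generation-counter contract these launcher hooks implement is small enough to restate standalone. A minimal sketch of the scheme (illustrative, self-contained types - `GenCounters` / `CachedEntry` are not the Quadrants API): every tracked write bumps a per-resource counter, and a cached entry is reusable only while every counter it snapshotted still matches.

    #include <cstdint>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    struct GenCounters {
      std::unordered_map<int, uint64_t> gen;  // resource id -> write generation
      void bump(int id) { ++gen[id]; }        // called on every tracked write
      uint64_t current(int id) const {
        auto it = gen.find(id);
        return it == gen.end() ? 0 : it->second;  // never-written resources sit at gen 0
      }
    };

    struct CachedEntry {
      int64_t value = 0;
      std::vector<std::pair<int, uint64_t>> deps;  // (resource id, gen at record time)
    };

    // Valid only while no dependency has been bumped since the entry was recorded.
    inline bool still_valid(const CachedEntry &e, const GenCounters &c) {
      for (const auto &dep : e.deps) {
        if (c.current(dep.first) != dep.second) return false;
      }
      return true;
    }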
+void bump_writes_for_kernel_llvm(Program *prog,
+                                 LaunchContextBuilder *ctx,
+                                 const std::vector<std::vector<int>> &snode_writes_per_task,
+                                 const std::vector<std::vector<int>> &arr_writes_per_task,
+                                 const std::vector<std::vector<int>> &arr_reads_per_task) {
+  if (prog == nullptr) {
+    return;
+  }
+  auto bump_data_ptr = [&](int arg_id) {
+    ArgArrayPtrKey data_key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
+    auto it = ctx->array_ptrs.find(data_key);
+    if (it != ctx->array_ptrs.end() && it->second != nullptr) {
+      prog->adstack_cache().bump_ndarray_data_gen(it->second);
+    }
+  };
+  for (const auto &task_snodes : snode_writes_per_task) {
+    for (int snode_id : task_snodes) {
+      prog->adstack_cache().bump_snode_write_gen(snode_id);
+    }
+  }
+  for (const auto &task_args : arr_writes_per_task) {
+    for (int arg_id : task_args) {
+      bump_data_ptr(arg_id);
+    }
+  }
+  // Read-only `DevAllocType::kNone` args: see the comment in the CUDA / AMDGPU overload for why CPU LLVM also
+  // needs the bump. Empty `arr_reads_per_task` is the legal cache-miss path (offline-cache load that did not
+  // capture per-task arr_reads); skip the loop without raising.
+  for (const auto &task_args : arr_reads_per_task) {
+    for (int arg_id : task_args) {
+      auto type_it = ctx->device_allocation_type.find(arg_id);
+      if (type_it == ctx->device_allocation_type.end() ||
+          type_it->second != LaunchContextBuilder::DevAllocType::kNone) {
+        continue;
+      }
+      bump_data_ptr(arg_id);
+    }
+  }
+}
+
+void bump_writes_for_kernel_spirv(
+    Program *prog,
+    LaunchContextBuilder *ctx,
+    const std::vector<spirv::TaskAttributes> &task_attribs,
+    const std::vector<std::pair<std::vector<int>, irpass::ExternalPtrAccess>> &arr_access) {
+  if (prog == nullptr) {
+    return;
+  }
+  for (const auto &task : task_attribs) {
+    for (int snode_id : task.snode_writes) {
+      prog->adstack_cache().bump_snode_write_gen(snode_id);
+    }
+  }
+  for (const auto &kv : arr_access) {
+    const std::vector<int> &indices = kv.first;
+    uint32_t access = uint32_t(kv.second);
+    QD_ASSERT(indices.size() == 1);
+    int arg_id = indices[0];
+    bool kernel_writes = (access & uint32_t(irpass::ExternalPtrAccess::WRITE)) != 0;
+    bool kone_h2d_blit = (access & uint32_t(irpass::ExternalPtrAccess::READ)) != 0 &&
+                         ctx->device_allocation_type[arg_id] == LaunchContextBuilder::DevAllocType::kNone;
+    if (!kernel_writes && !kone_h2d_blit) {
+      continue;
+    }
+    ArgArrayPtrKey data_key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
+    auto it = ctx->array_ptrs.find(data_key);
+    if (it != ctx->array_ptrs.end()) {
+      prog->adstack_cache().bump_ndarray_data_gen(it->second);
+    }
+  }
+}
+
+}  // namespace quadrants::lang
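Concretely, the clip defined above reduces to a small decision table; a reduced sketch with a hypothetical flat signature (the real helper takes `StaticAdStackBoundExpr` plus `Program` / `LaunchContextBuilder`, as declared in the header below):

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>

    // trip_static > 0 means a compile-time trip count; otherwise trip_runtime carries
    // the evaluated SizeExpr result (-1 or 0 when unresolvable).
    inline void clip_rows(std::size_t &rows, int64_t trip_static, int64_t trip_runtime,
                          std::size_t thread_ceiling) {
      const int64_t trip = trip_static > 0 ? trip_static : trip_runtime;
      if (trip <= 0) {
        return;  // no usable trip count: leave rows at the unclipped reducer count
      }
      if (static_cast<std::size_t>(trip) > thread_ceiling) {
        return;  // iterations serialise across threads: clipping could undersize the heap
      }
      rows = std::min(rows, static_cast<std::size_t>(trip));
    }
    // e.g. rows = 4096 gating cells, trip_static = 128, huge ceiling  ->  rows = 128.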
diff --git a/quadrants/program/adstack/write_gen.h b/quadrants/program/adstack/write_gen.h
new file mode 100644
index 0000000000..e66742d46d
--- /dev/null
+++ b/quadrants/program/adstack/write_gen.h
@@ -0,0 +1,64 @@
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+#include "quadrants/codegen/llvm/llvm_compiled_data.h"
+#include "quadrants/codegen/spirv/kernel_utils.h"
+#include "quadrants/transforms/static_adstack_analysis.h"
+
+namespace quadrants::lang {
+
+class LaunchContextBuilder;
+class Program;
+
+// Apply the captured per-task loop trip-count clip to `effective_rows`. Each loop iteration of an adstack
+// task claims at most one row at the LCA-block, so the heap needs at most `trip_count` rows regardless of
+// how many cells of an oversized gating SNode/ndarray the reducer counted. Two trip-count sources, picked
+// in order: `bound_expr.loop_iter_static` (compile-time-known constant, integer compare) and
+// `bound_expr.loop_iter_size_expr` (per-launch tree walk via `evaluate_adstack_size_expr`). Both are
+// gated by `dispatched_threads_ceiling` so a `dynamic_gpu_range_for` that exceeds the dispatch cap and
+// serialises iterations across threads (each thread reaches the LCA-block multiple times) does not
+// accidentally undersize the heap; pass `std::numeric_limits<std::size_t>::max()` to disable the
+// ceiling. No-op when the static field is zero AND the SizeExpr is empty (the analyzer leaves both
+// unset for shapes the compile-time path cannot cover) - the caller's pre-clip `effective_rows` is left
+// unchanged so the runtime falls through to the unclipped reducer count.
+void clip_effective_rows_by_loop_trip_count(std::size_t &effective_rows,
+                                            const StaticAdStackBoundExpr &bound_expr,
+                                            std::size_t dispatched_threads_ceiling,
+                                            Program *prog,
+                                            LaunchContextBuilder *ctx);
+
+// Adstack-cache invalidation bump. Called from each backend's kernel launcher BEFORE the per-task
+// `publish_adstack_metadata` loop runs, so the per-task metadata cache (`Program::*PerTaskAdStackCacheEntry`) snapshots
+// the latest counters at record time and the next lookup detects any drift. Two sources contribute:
+//
+// - SNode writes: every task in the kernel lists its compile-time `snode_writes` set (computed at codegen via
+//   `irpass::analysis::gather_snode_read_writes`), bumped per id; covers `SizeExpr::FieldLoad` cache invalidation.
+// - ndarray data writes: every arg slot the kernel writes to (`OffloadedTask::arr_writes` on LLVM-GPU, the kernel-
+//   level `ctx_attribs.arr_access` WRITE bits on SPIR-V) bumps the bound `DeviceAllocation`'s data generation.
+//   SPIR-V also bumps on the `kNone` READ branch to catch host-driven mutations of raw numpy / torch buffers blitted
+//   between launches; covers `SizeExpr::ExternalTensorRead` invalidation.
+//
+// The two helpers share the same Program-level effect; their signatures differ only because the codegen-time write
+// sets are stored in different per-backend structs. Forward-only kernels (no adstack tasks) still call these to keep
+// counters monotone, which is cheap (one map insert per snode_id at most).
+void bump_writes_for_kernel_llvm(Program *prog,
+                                 LaunchContextBuilder *ctx,
+                                 const std::vector<OffloadedTask> &offloaded_tasks);
+// CPU launcher overload: per-task snode_writes / arr_writes / arr_reads are stored as separate parallel vectors on
+// the launcher `Context` rather than as `OffloadedTask` clones, for legacy reasons documented in the CPU `Context`
+// struct.
+void bump_writes_for_kernel_llvm(Program *prog,
+                                 LaunchContextBuilder *ctx,
+                                 const std::vector<std::vector<int>> &snode_writes_per_task,
+                                 const std::vector<std::vector<int>> &arr_writes_per_task,
+                                 const std::vector<std::vector<int>> &arr_reads_per_task);
+void bump_writes_for_kernel_spirv(
+    Program *prog,
+    LaunchContextBuilder *ctx,
+    const std::vector<spirv::TaskAttributes> &task_attribs,
+    const std::vector<std::pair<std::vector<int>, irpass::ExternalPtrAccess>> &arr_access);
+
+}  // namespace quadrants::lang
diff --git a/quadrants/program/adstack_size_expr_eval.cpp b/quadrants/program/adstack_size_expr_eval.cpp
deleted file mode 100644
index 068331d515..0000000000
--- a/quadrants/program/adstack_size_expr_eval.cpp
+++ /dev/null
@@ -1,1984 +0,0 @@
-#include "quadrants/program/adstack_size_expr_eval.h"
-
-#include <algorithm>
-#include <cstring>
-#include <functional>
-#include <limits>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-
-#include "quadrants/codegen/llvm/llvm_compiled_data.h"
-#include "quadrants/codegen/spirv/adstack_sizer_shader.h"
-#include "quadrants/common/logging.h"
-#include "quadrants/ir/adstack_size_expr_device.h"
-#include "quadrants/ir/snode.h"
-#include "quadrants/ir/type.h"
-#include "quadrants/ir/type_factory.h"
-#include "quadrants/ir/type_utils.h"
-#include "quadrants/program/launch_context_builder.h"
-#include "quadrants/program/program.h"
-#include "quadrants/program/snode_rw_accessors_bank.h"
-#include "quadrants/rhi/device.h"
-
-namespace quadrants::lang {
-
-namespace {
-
-using ReadSink = std::vector<AdStackCache::SizeExprReadObservation>;
-
-// Forward-declared, defined further down. Reads SNode `snode_id` at `indices` via the per-launch read
-// cache (when active) so multiple size-expr trees evaluated within the same outer launch share a single
-// reader-kernel dispatch per `(snode_id, indices)` pair.
-int64_t read_field_with_launch_cache(int snode_id, const std::vector<int> &indices, Program *prog);
-
-int64_t evaluate_node(const SerializedSizeExpr &expr,
-                      int32_t node_idx,
-                      std::unordered_map<int32_t, int64_t> &bound_vars,
-                      Program *prog,
-                      LaunchContextBuilder *ctx,
-                      ReadSink *reads);
-
-int64_t evaluate_field_load(const SerializedSizeExprNode &node,
-                            std::unordered_map<int32_t, int64_t> &bound_vars,
-                            Program *prog,
-                            ReadSink *reads) {
-  QD_ASSERT_INFO(node.snode_id >= 0, "SerializedSizeExpr FieldLoad with no snode_id");
-  SNode *snode = prog->get_snode_by_id(node.snode_id);
-  QD_ASSERT_INFO(snode != nullptr,
-                 "SerializedSizeExpr FieldLoad snode_id={} not found in the current program's snode trees",
-                 node.snode_id);
-  std::vector<int> indices;
-  indices.reserve(node.indices.size());
-  for (int32_t raw : node.indices) {
-    if (raw >= 0) {
-      indices.push_back(raw);
-    } else {
-      int32_t var_id = -(raw + 1);
-      auto it = bound_vars.find(var_id);
-      QD_ASSERT_INFO(it != bound_vars.end(),
-                     "SerializedSizeExpr FieldLoad references unbound var_id={} (the enclosing MaxOverRange "
-                     "node must have bound it before this read)",
-                     var_id);
-      indices.push_back(static_cast<int>(it->second));
-    }
-  }
-  int64_t v = read_field_with_launch_cache(node.snode_id, indices, prog);
-  if (reads != nullptr) {
-    AdStackCache::SizeExprReadObservation obs;
-    obs.kind = AdStackCache::SizeExprReadObservation::FieldLoadObs;
-    obs.snode_id = node.snode_id;
-    obs.indices = std::move(indices);
-    obs.arg_shape_axis = 0;
-    obs.prim_dt = 0;
-    obs.observed_value = v;
-    // Snapshot the SNode's write gen so the next replay can fast-skip when no kernel has written this SNode
-    // since record time (the dominant case for a steady-state reverse-mode loop with stable bounds).
-    obs.observed_gen = prog->adstack_cache().snode_write_gen(node.snode_id);
-    reads->push_back(std::move(obs));
-  }
-  return v;
-}
-
-int64_t evaluate_external_tensor_read(const SerializedSizeExprNode &node,
-                                      std::unordered_map<int32_t, int64_t> &bound_vars,
-                                      Program *prog,
-                                      LaunchContextBuilder *ctx,
-                                      ReadSink *reads) {
-  QD_ASSERT_INFO(ctx != nullptr,
-                 "SerializedSizeExpr ExternalTensorRead evaluated with no LaunchContextBuilder; the launcher "
-                 "must pass the current launch's context in");
-  QD_ASSERT_INFO(!node.arg_id_path.empty(), "SerializedSizeExpr ExternalTensorRead has empty arg_id_path");
-  int arg_id = node.arg_id_path[0];
-  ArgArrayPtrKey key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
-  auto it = ctx->array_ptrs.find(key);
-  QD_ASSERT_INFO(it != ctx->array_ptrs.end(),
-                 "SerializedSizeExpr ExternalTensorRead: arg {} has no data pointer in launch context", arg_id);
-  void *data_ptr = it->second;
-  // Resolve each index (possibly via a bound variable) and compose them into the C-order linear offset
-  // `sum_i(idx_i * prod_{j>i}(shape_j))`. Multi-dim shapes are read from the launch context through the same
-  // `SHAPE_POS_IN_NDARRAY` path `ExternalTensorShape` uses, so an ndarray indexed by two or more loop variables lowers
-  // to the correct element rather than the stride-1 sum `arr_flat[i + j + ...]`. Mirrors the per-axis stride that
-  // `encode_subtree` precomputes on the SPIR-V path; on CPU the host evaluator is called directly from
-  // `publish_adstack_metadata`, so the stride math has to live here too.
-  std::vector<int64_t> resolved(node.indices.size());
-  for (std::size_t i = 0; i < node.indices.size(); ++i) {
-    int32_t raw = node.indices[i];
-    if (raw >= 0) {
-      resolved[i] = raw;
-    } else {
-      int32_t var_id = -(raw + 1);
-      auto bv = bound_vars.find(var_id);
-      QD_ASSERT_INFO(bv != bound_vars.end(), "SerializedSizeExpr ExternalTensorRead references unbound var_id={}",
-                     var_id);
-      resolved[i] = bv->second;
-    }
-  }
-  int64_t linear = 0;
-  int64_t stride = 1;
-  for (std::size_t i = node.indices.size(); i > 0; --i) {
-    linear += resolved[i - 1] * stride;
-    if (i - 1 > 0) {
-      std::vector<int> sh_idx(node.arg_id_path.begin(), node.arg_id_path.end());
-      sh_idx.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY);
-      sh_idx.push_back(static_cast<int>(i - 1));
-      // Ndarray shapes are `int32` in the args struct (same convention `evaluate_external_tensor_shape` relies on);
-      // reading as `int64` would sign-extend the adjacent slot into the shape and produce garbage strides.
-      stride *= static_cast<int64_t>(ctx->get_struct_arg_host<int32_t>(sh_idx));
-    }
-  }
-  auto prim_dt = static_cast<PrimitiveTypeID>(node.const_value);
-  int64_t v;
-  switch (prim_dt) {
-    case PrimitiveTypeID::i32:
-      v = static_cast<int64_t>(static_cast<const int32_t *>(data_ptr)[linear]);
-      break;
-    case PrimitiveTypeID::i64:
-      v = static_cast<const int64_t *>(data_ptr)[linear];
-      break;
-    case PrimitiveTypeID::u32:
-      v = static_cast<int64_t>(static_cast<const uint32_t *>(data_ptr)[linear]);
-      break;
-    case PrimitiveTypeID::u64:
-      v = static_cast<int64_t>(static_cast<const uint64_t *>(data_ptr)[linear]);
-      break;
-    case PrimitiveTypeID::i16:
-      v = static_cast<int64_t>(static_cast<const int16_t *>(data_ptr)[linear]);
-      break;
-    case PrimitiveTypeID::u16:
-      v = static_cast<int64_t>(static_cast<const uint16_t *>(data_ptr)[linear]);
-      break;
-    case PrimitiveTypeID::i8:
-      v = static_cast<int64_t>(static_cast<const int8_t *>(data_ptr)[linear]);
-      break;
-    case PrimitiveTypeID::u8:
-      v = static_cast<int64_t>(static_cast<const uint8_t *>(data_ptr)[linear]);
-      break;
-    default:
-      QD_ERROR("SerializedSizeExpr ExternalTensorRead: unsupported element type {}", node.const_value);
-      v = 0;
-  }
-  if (reads != nullptr) {
-    AdStackCache::SizeExprReadObservation obs;
-    obs.kind = AdStackCache::SizeExprReadObservation::ExternalReadObs;
-    obs.snode_id = 0;
-    obs.indices.reserve(resolved.size());
-    for (auto r : resolved)
-      obs.indices.push_back(static_cast<int>(r));
-    obs.arg_id_path = node.arg_id_path;
-    obs.arg_shape_axis = 0;
-    obs.prim_dt = static_cast<int>(prim_dt);
-    obs.observed_value = v;
-    obs.observed_devalloc = data_ptr;
-    if (prog != nullptr) {
-      // Snapshot the ndarray's data gen so the next replay can fast-skip when no kernel / Ndarray API write
-      // has touched the underlying buffer since record time. Mirrors the FieldLoad fast-skip; covers the same
-      // steady-state hot path for ndarray-bounded reverse-mode loops.
-      obs.observed_gen = prog->adstack_cache().ndarray_data_gen(data_ptr);
-    }
-    reads->push_back(std::move(obs));
-  }
-  return v;
-}
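The stride walk above is plain C-order linearization, `linear = sum_i idx_i * prod_{j>i} shape_j`; restated standalone with the shape passed in directly instead of re-read from the args struct (a sketch, not shared code):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    inline int64_t c_order_linear(const std::vector<int64_t> &idx,
                                  const std::vector<int32_t> &shape) {
      int64_t linear = 0;
      int64_t stride = 1;
      for (std::size_t i = idx.size(); i > 0; --i) {
        linear += idx[i - 1] * stride;  // innermost axis contributes with stride 1
        stride *= shape[i - 1];         // next-outer axis strides over this extent
      }
      return linear;
    }
    // c_order_linear({2, 3}, {4, 5}) == 2 * 5 + 3 == 13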
-int64_t evaluate_external_tensor_shape(const SerializedSizeExprNode &node, LaunchContextBuilder *ctx, ReadSink *reads) {
-  QD_ASSERT_INFO(ctx != nullptr,
-                 "SerializedSizeExpr ExternalTensorShape evaluated with no LaunchContextBuilder; the launcher "
-                 "must pass the current launch's context into the evaluator to resolve ndarray shapes");
-  std::vector<int> arg_indices(node.arg_id_path.begin(), node.arg_id_path.end());
-  arg_indices.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY);
-  arg_indices.push_back(node.arg_shape_axis);
-  // Ndarray shape slots are `int32` in the args struct (same convention `evaluate_external_tensor_read` relies
-  // on for its stride multiplies). Using `int64` here reads 8 bytes past the slot and sign-extends the next
-  // field into the shape, so a user-visible downstream effect is that any `SizeExpr` node that feeds a
-  // shape-derived value into a trip count (e.g. `MaxOverRange(0, ExtShape, ...)`) evaluates its range as
-  // garbage - often zero when the adjacent field is zero-initialised - and the containing tree collapses to
-  // zero. The adstack max_size is clamped to 1 on a zero tree result, which under-bounds real push counts and
-  // trips an overflow assertion at the next `qd.sync()`.
-  int64_t v = static_cast<int64_t>(ctx->get_struct_arg_host<int32_t>(arg_indices));
-  if (reads != nullptr) {
-    AdStackCache::SizeExprReadObservation obs;
-    obs.kind = AdStackCache::SizeExprReadObservation::ExternalShapeObs;
-    obs.snode_id = 0;
-    obs.arg_id_path = node.arg_id_path;
-    obs.arg_shape_axis = node.arg_shape_axis;
-    obs.prim_dt = 0;
-    obs.observed_value = v;
-    reads->push_back(std::move(obs));
-  }
-  return v;
-}
-
-int64_t evaluate_node(const SerializedSizeExpr &expr,
-                      int32_t node_idx,
-                      std::unordered_map<int32_t, int64_t> &bound_vars,
-                      Program *prog,
-                      LaunchContextBuilder *ctx,
-                      ReadSink *reads) {
-  QD_ASSERT_INFO(node_idx >= 0 && static_cast<std::size_t>(node_idx) < expr.nodes.size(),
-                 "SerializedSizeExpr node_idx {} out of bounds (size={})", node_idx, expr.nodes.size());
-  const auto &node = expr.nodes[node_idx];
-  switch (static_cast<SizeExpr::Kind>(node.kind)) {
-    case SizeExpr::Kind::Const:
-      return node.const_value;
-    case SizeExpr::Kind::FieldLoad:
-      return evaluate_field_load(node, bound_vars, prog, reads);
-    case SizeExpr::Kind::Add:
-      return evaluate_node(expr, node.operand_a, bound_vars, prog, ctx, reads) +
-             evaluate_node(expr, node.operand_b, bound_vars, prog, ctx, reads);
-    case SizeExpr::Kind::Sub:
-      return std::max<int64_t>(evaluate_node(expr, node.operand_a, bound_vars, prog, ctx, reads) -
-                                   evaluate_node(expr, node.operand_b, bound_vars, prog, ctx, reads),
-                               0);
-    case SizeExpr::Kind::Mul:
-      return evaluate_node(expr, node.operand_a, bound_vars, prog, ctx, reads) *
-             evaluate_node(expr, node.operand_b, bound_vars, prog, ctx, reads);
-    case SizeExpr::Kind::Max:
-      return std::max(evaluate_node(expr, node.operand_a, bound_vars, prog, ctx, reads),
-                      evaluate_node(expr, node.operand_b, bound_vars, prog, ctx, reads));
-    case SizeExpr::Kind::MaxOverRange: {
-      int64_t begin = evaluate_node(expr, node.operand_a, bound_vars, prog, ctx, reads);
-      int64_t end = evaluate_node(expr, node.operand_b, bound_vars, prog, ctx, reads);
-      // Guard against pathological trip counts. The evaluator walks `[begin, end)` linearly and re-evaluates the
-      // body at every i; a range of several million would stall the launch hot path for seconds. Real reverse-mode
-      // trip counts sit well below this cap (a few hundred to a few thousand in practice); anything above is
-      // almost certainly a pre-pass grammar bug the user should file, and a clear QD_ERROR beats a silent hang.
-      constexpr int64_t kMaxOverRangeIterations = int64_t{1} << 24;
-      QD_ERROR_IF(end > begin && end - begin > kMaxOverRangeIterations,
-                  "SerializedSizeExpr MaxOverRange iteration count {} exceeds the {} guard; refusing to enumerate. "
" - "Shrink the enclosing reverse-mode loop or restructure the `SizeExpr` source kernel.", - end - begin, kMaxOverRangeIterations); - int64_t result = 0; - // Bind `var_id` in `bound_vars` for the duration of the loop and restore the outer-scope value (or erase, if - // there was none) before returning, so nested `MaxOverRange` bindings of the same `var_id` stay correct without - // cloning the entire map per iteration. - auto prev_it = bound_vars.find(node.var_id); - bool had_prev = prev_it != bound_vars.end(); - int64_t prev_val = had_prev ? prev_it->second : 0; - for (int64_t i = begin; i < end; ++i) { - bound_vars[node.var_id] = i; - int64_t v = evaluate_node(expr, node.body_node_idx, bound_vars, prog, ctx, reads); - if (v > result) { - result = v; - } - } - if (had_prev) { - bound_vars[node.var_id] = prev_val; - } else { - bound_vars.erase(node.var_id); - } - return result; - } - case SizeExpr::Kind::BoundVariable: { - auto it = bound_vars.find(node.var_id); - QD_ASSERT_INFO(it != bound_vars.end(), - "SerializedSizeExpr BoundVariable var_id={} evaluated outside its MaxOverRange scope", - node.var_id); - return it->second; - } - case SizeExpr::Kind::ExternalTensorShape: - return evaluate_external_tensor_shape(node, ctx, reads); - case SizeExpr::Kind::ExternalTensorRead: - return evaluate_external_tensor_read(node, bound_vars, prog, ctx, reads); - } - QD_ERROR("unreachable SerializedSizeExpr kind {}", node.kind); - return 0; -} - -// -------------------------------------------------------------------------------------------------------------- -// Device-bytecode encoder helpers -// -------------------------------------------------------------------------------------------------------------- - -// `contains_device_leaf[i]` is true when subtree rooted at node `i` has at least one leaf that MUST stay on the -// device during encoding (the host fold cannot substitute it with a `Const`). On the LLVM path this is any -// `ExternalTensorRead` leaf - `FieldLoad` can be host-folded via `SNodeRwAccessorsBank::read_int`, which is safe -// on CPU / CUDA / AMDGPU. On the SPIR-V path the caller flips `fieldload_stays_on_device` to true because on -// MoltenVK a nested `read_int` submit crashes inside the descriptor-set bind path; keeping `FieldLoad` on the -// device side (via PSB loads in the sizer shader) avoids that entirely. Computed bottom-up; `SerializedSizeExpr` -// is already in post-order so every operand / body index is < i. -std::vector compute_contains_device_leaf(const SerializedSizeExpr &expr, bool fieldload_stays_on_device) { - std::vector result(expr.nodes.size(), false); - for (std::size_t i = 0; i < expr.nodes.size(); ++i) { - const auto &node = expr.nodes[i]; - auto kind = static_cast(node.kind); - bool hit = (kind == SizeExpr::Kind::ExternalTensorRead) || - (fieldload_stays_on_device && kind == SizeExpr::Kind::FieldLoad); - if (!hit && node.operand_a >= 0) - hit = result[node.operand_a]; - if (!hit && node.operand_b >= 0) - hit = result[node.operand_b]; - if (!hit && node.body_node_idx >= 0) - hit = result[node.body_node_idx]; - result[i] = hit; - } - return result; -} - -// `free_vars[i]` is the set of `BoundVariable::var_id`s referenced inside subtree(i) but NOT bound by any -// `MaxOverRange` inside that same subtree. An empty set means the subtree is closed and can be evaluated on the -// host without an outer-iteration context. `FieldLoad` / `ExternalTensorRead` index slots use the same -// `-(var_id + 1)` encoding as `BoundVariable` and are accounted for here. 
-// --------------------------------------------------------------------------------------------------------------
-// Device-bytecode encoder helpers
-// --------------------------------------------------------------------------------------------------------------
-
-// `contains_device_leaf[i]` is true when subtree rooted at node `i` has at least one leaf that MUST stay on the
-// device during encoding (the host fold cannot substitute it with a `Const`). On the LLVM path this is any
-// `ExternalTensorRead` leaf - `FieldLoad` can be host-folded via `SNodeRwAccessorsBank::read_int`, which is safe
-// on CPU / CUDA / AMDGPU. On the SPIR-V path the caller flips `fieldload_stays_on_device` to true because on
-// MoltenVK a nested `read_int` submit crashes inside the descriptor-set bind path; keeping `FieldLoad` on the
-// device side (via PSB loads in the sizer shader) avoids that entirely. Computed bottom-up; `SerializedSizeExpr`
-// is already in post-order so every operand / body index is < i.
-std::vector<bool> compute_contains_device_leaf(const SerializedSizeExpr &expr, bool fieldload_stays_on_device) {
-  std::vector<bool> result(expr.nodes.size(), false);
-  for (std::size_t i = 0; i < expr.nodes.size(); ++i) {
-    const auto &node = expr.nodes[i];
-    auto kind = static_cast<SizeExpr::Kind>(node.kind);
-    bool hit = (kind == SizeExpr::Kind::ExternalTensorRead) ||
-               (fieldload_stays_on_device && kind == SizeExpr::Kind::FieldLoad);
-    if (!hit && node.operand_a >= 0)
-      hit = result[node.operand_a];
-    if (!hit && node.operand_b >= 0)
-      hit = result[node.operand_b];
-    if (!hit && node.body_node_idx >= 0)
-      hit = result[node.body_node_idx];
-    result[i] = hit;
-  }
-  return result;
-}
-
-// `free_vars[i]` is the set of `BoundVariable::var_id`s referenced inside subtree(i) but NOT bound by any
-// `MaxOverRange` inside that same subtree. An empty set means the subtree is closed and can be evaluated on the
-// host without an outer-iteration context. `FieldLoad` / `ExternalTensorRead` index slots use the same
-// `-(var_id + 1)` encoding as `BoundVariable` and are accounted for here.
-std::vector<std::unordered_set<int32_t>> compute_free_vars(const SerializedSizeExpr &expr) {
-  std::vector<std::unordered_set<int32_t>> result(expr.nodes.size());
-  for (std::size_t i = 0; i < expr.nodes.size(); ++i) {
-    const auto &node = expr.nodes[i];
-    auto &fv = result[i];
-    auto collect_idx_vars = [&](const std::vector<int32_t> &indices) {
-      for (int32_t raw : indices) {
-        if (raw < 0)
-          fv.insert(-(raw + 1));
-      }
-    };
-    switch (static_cast<SizeExpr::Kind>(node.kind)) {
-      case SizeExpr::Kind::Const:
-      case SizeExpr::Kind::ExternalTensorShape:
-        break;
-      case SizeExpr::Kind::BoundVariable:
-        fv.insert(node.var_id);
-        break;
-      case SizeExpr::Kind::FieldLoad:
-      case SizeExpr::Kind::ExternalTensorRead:
-        collect_idx_vars(node.indices);
-        break;
-      case SizeExpr::Kind::Add:
-      case SizeExpr::Kind::Sub:
-      case SizeExpr::Kind::Mul:
-      case SizeExpr::Kind::Max:
-        fv = result[node.operand_a];
-        for (auto v : result[node.operand_b])
-          fv.insert(v);
-        break;
-      case SizeExpr::Kind::MaxOverRange: {
-        fv = result[node.operand_a];
-        for (auto v : result[node.operand_b])
-          fv.insert(v);
-        // MaxOverRange binds `var_id` for its body only: body's free vars minus this binding add into the
-        // outer set.
-        for (auto v : result[node.body_node_idx]) {
-          if (v != node.var_id)
-            fv.insert(v);
-        }
-        break;
-      }
-    }
-  }
-  return result;
-}
-
-// Walks `expr` and builds a dense `original_var_id -> [0, N)` map across every `var_id` the tree references
-// (`MaxOverRange` binds, `BoundVariable` leaves, and bound-var entries inside each ETR / FieldLoad index list).
-// The walker preserves encounter order so nested `MaxOverRange` binds keep monotonically increasing dense ids,
-// which also matches the natural `values[]` indexing the device interpreter does at each bind. Hard-errors if
-// the tree references more distinct bound vars than the device interpreter's per-stack scope capacity.
-std::unordered_map<int32_t, int32_t> build_dense_var_id_remap(const SerializedSizeExpr &expr) {
-  std::unordered_map<int32_t, int32_t> remap;
-  auto add = [&](int32_t v) {
-    if (v < 0)
-      return;
-    if (remap.find(v) == remap.end()) {
-      int32_t dense = static_cast<int32_t>(remap.size());
-      remap.emplace(v, dense);
-    }
-  };
-  for (const auto &node : expr.nodes) {
-    const auto kind = static_cast<SizeExpr::Kind>(node.kind);
-    if (kind == SizeExpr::Kind::MaxOverRange || kind == SizeExpr::Kind::BoundVariable)
-      add(node.var_id);
-    for (int32_t raw : node.indices) {
-      if (raw < 0)
-        add(-(raw + 1));
-    }
-  }
-  QD_ERROR_IF(static_cast<int32_t>(remap.size()) > kAdStackSizeExprDeviceMaxBoundVars,
-              "Adstack SizeExpr tree references {} distinct bound variable ids, which exceeds the device "
-              "interpreter's per-stack scope capacity ({}). This almost always indicates a deeply nested "
-              "reverse-mode loop shape that the pre-pass should have folded earlier; shrink the enclosing "
-              "loops or file a bug so the grammar / walker can be tightened.",
-              remap.size(), kAdStackSizeExprDeviceMaxBoundVars);
-  return remap;
-}
-
-// Computes the maximum `MaxOverRange` nesting depth reachable from any root in `expr`, i.e. the deepest
-// chain of `MaxOverRange` nodes whose `body_node_idx` recursively references another `MaxOverRange`. The
-// sizer shader's per-invocation pending-frame stack is sized to `kAdStackSizerMaxPendingFrames`; the encoder
-// hard-errors when a tree's nesting exceeds this so the shader's fixed-size access-chain stays in bounds
-// without a runtime guard. Each node's depth is memoised to keep the walk linear in `expr.nodes.size()`.
-int32_t compute_max_mor_nesting(const SerializedSizeExpr &expr) {
-  std::vector<int32_t> depth(expr.nodes.size(), -1);
-  std::function<int32_t(int32_t)> visit = [&](int32_t i) -> int32_t {
-    if (i < 0 || static_cast<std::size_t>(i) >= expr.nodes.size())
-      return 0;
-    if (depth[i] >= 0)
-      return depth[i];
-    const auto &n = expr.nodes[i];
-    int32_t child_max = 0;
-    for (int32_t c : {n.operand_a, n.operand_b, n.body_node_idx}) {
-      if (c >= 0)
-        child_max = std::max(child_max, visit(c));
-    }
-    int32_t self = static_cast<SizeExpr::Kind>(n.kind) == SizeExpr::Kind::MaxOverRange ? 1 : 0;
-    depth[i] = self + child_max;
-    return depth[i];
-  };
-  int32_t max_depth = 0;
-  for (std::size_t i = 0; i < expr.nodes.size(); ++i) {
-    max_depth = std::max(max_depth, visit(static_cast<int32_t>(i)));
-  }
-  return max_depth;
-}
-
-// Returns the dense id for `original_var_id`, or fires a hard error if the remap lost track of it (which would
-// indicate a walker divergence between `build_dense_var_id_remap` and `encode_subtree`).
-int32_t remap_var_id(const std::unordered_map<int32_t, int32_t> &remap, int32_t original) {
-  auto it = remap.find(original);
-  QD_ASSERT_INFO(it != remap.end(),
-                 "Adstack SizeExpr encoder saw var_id={} not present in the dense remap; this "
-                 "is a walker bug between `build_dense_var_id_remap` and `encode_subtree`.",
-                 original);
-  return it->second;
-}
-
-// Initialises a fresh device node with every unused slot sentinelled so the interpreter can tell them apart from
-// legitimate zero-valued slots (e.g. `operand_a == 0` is a valid node index; only `-1` signals "unused").
-AdStackSizeExprDeviceNode make_empty_device_node(int32_t kind) {
-  AdStackSizeExprDeviceNode dn{};
-  dn.kind = kind;
-  dn.operand_a = -1;
-  dn.operand_b = -1;
-  dn.body_node_idx = -1;
-  dn.var_id = -1;
-  dn.prim_dt = -1;
-  dn.arg_buffer_offset = -1;
-  dn.indices_offset = 0;
-  dn.indices_count = 0;
-  dn._pad0 = 0;
-  dn.const_value = 0;
-  return dn;
-}
-
-// Data needed to encode a `FieldLoad` as a `kFieldLoad` device node. Populated by the SPIR-V encoder entry
-// point via `GfxRuntime` / `Device` queries; the LLVM encoder passes a default-constructed (empty) emitter,
-// which routes every `FieldLoad` through the host-fold path instead (safe on CPU / CUDA / AMDGPU where a
-// nested accessor kernel launch is fine).
-struct FieldLoadDeviceEmitter {
-  // Returns true on success, populating `out_base_psb` with `root_buffer_psb + place_byte_offset_in_root` and
-  // `out_elem_strides` with one positive int32 *element* stride per active axis of `snode` (stride in units of
-  // the leaf's primitive type, not bytes - the sizer shader reuses `psb_load_scalar` which already multiplies
-  // by `sizeof(prim_dt)`). Returns false when the snode layout is not amenable to direct PSB indexing
-  // (bitmasked / pointer / hash chain, bit-level place, not-all-dense path), in which case the encoder raises
-  // a `QD_ERROR`. The dense-only restriction is deliberate - observed kernels exercise only dense chains in the
-  // adstack pre-pass's `SizeExpr::FieldLoad` leaves, and extending this to bitmasked / pointer would require
-  // threading the full access codegen through the sizer shader, which is out of scope.
-  std::function<bool(SNode *snode, uint64_t *out_base_psb, std::vector<int32_t> *out_elem_strides)> fetch;
-
-  bool empty() const {
-    return fetch == nullptr;
-  }
-};
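The encoder below combines the two pre-passes into one fold predicate; stated standalone (a sketch mirroring the `contains_device_leaf` / `free_vars` vectors computed above):

    #include <cstdint>
    #include <unordered_set>
    #include <vector>

    // A subtree collapses to a single kConst device node exactly when the host can
    // evaluate it now: no leaf that must stay device-resident, and no bound
    // variable inherited from an enclosing MaxOverRange.
    inline bool can_host_fold(int32_t node,
                              const std::vector<bool> &needs_device,
                              const std::vector<std::unordered_set<int32_t>> &free_vars) {
      return !needs_device[node] && free_vars[node].empty();
    }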
-// Recursive top-down encoder. Each call returns the index of the emitted root in `out_nodes`. Subtrees whose
-// leaves are all host-resolvable (no `ExternalTensorRead`, and - on the LLVM path - no `FieldLoad` either) and
-// whose bound variables are all locally bound within the subtree get folded to a single `kConst` device node
-// by running `evaluate_node` over them. On the SPIR-V path, `FieldLoad` also survives as a `kFieldLoad` device
-// node alongside `kExternalTensorRead`, so the shader can resolve the snode read in place via PSB.
-int32_t encode_subtree(const SerializedSizeExpr &src,
-                       int32_t src_idx,
-                       const std::vector<bool> &contains_device_leaf,
-                       const std::vector<std::unordered_set<int32_t>> &free_vars,
-                       const std::unordered_map<int32_t, int32_t> &var_id_remap,
-                       Program *prog,
-                       LaunchContextBuilder *ctx,
-                       const FieldLoadDeviceEmitter &fl_emitter,
-                       std::vector<AdStackSizeExprDeviceNode> &out_nodes,
-                       std::vector<int32_t> &out_indices,
-                       ReadSink *reads) {
-  QD_ASSERT_INFO(src_idx >= 0 && static_cast<std::size_t>(src_idx) < src.nodes.size(),
-                 "encode_subtree: src_idx {} out of bounds (size={})", src_idx, src.nodes.size());
-  const bool subtree_needs_device = contains_device_leaf[src_idx];
-  const bool subtree_closed = free_vars[src_idx].empty();
-
-  if (!subtree_needs_device && subtree_closed) {
-    // Whole subtree resolves without any device-resident read and without an outer-iteration context, so fold it
-    // to a single `Const` by running the host evaluator over it. This is the only path that can substitute
-    // `FieldLoad` / `ExternalTensorShape` leaves - the device interpreter does not know how to walk SNodes or
-    // index into `args_type`.
-    std::unordered_map<int32_t, int64_t> empty_bound;
-    int64_t val = evaluate_node(src, src_idx, empty_bound, prog, ctx, reads);
-    AdStackSizeExprDeviceNode dn = make_empty_device_node(static_cast<int32_t>(AdStackSizeExprDeviceKind::kConst));
-    dn.const_value = val;
-    out_nodes.push_back(dn);
-    return static_cast<int32_t>(out_nodes.size() - 1);
-  }
-
-  const auto &node = src.nodes[src_idx];
-  const auto kind = static_cast<SizeExpr::Kind>(node.kind);
-  switch (kind) {
-    case SizeExpr::Kind::Const: {
-      AdStackSizeExprDeviceNode dn = make_empty_device_node(static_cast<int32_t>(AdStackSizeExprDeviceKind::kConst));
-      dn.const_value = node.const_value;
-      out_nodes.push_back(dn);
-      return static_cast<int32_t>(out_nodes.size() - 1);
-    }
-    case SizeExpr::Kind::BoundVariable: {
-      AdStackSizeExprDeviceNode dn =
-          make_empty_device_node(static_cast<int32_t>(AdStackSizeExprDeviceKind::kBoundVariable));
-      dn.var_id = remap_var_id(var_id_remap, node.var_id);
-      out_nodes.push_back(dn);
-      return static_cast<int32_t>(out_nodes.size() - 1);
-    }
-    case SizeExpr::Kind::Add:
-    case SizeExpr::Kind::Sub:
-    case SizeExpr::Kind::Mul:
-    case SizeExpr::Kind::Max: {
-      int32_t a = encode_subtree(src, node.operand_a, contains_device_leaf, free_vars, var_id_remap, prog, ctx,
-                                 fl_emitter, out_nodes, out_indices, reads);
-      int32_t b = encode_subtree(src, node.operand_b, contains_device_leaf, free_vars, var_id_remap, prog, ctx,
-                                 fl_emitter, out_nodes, out_indices, reads);
-      AdStackSizeExprDeviceKind dk = AdStackSizeExprDeviceKind::kAdd;
-      if (kind == SizeExpr::Kind::Sub)
-        dk = AdStackSizeExprDeviceKind::kSub;
-      else if (kind == SizeExpr::Kind::Mul)
-        dk = AdStackSizeExprDeviceKind::kMul;
-      else if (kind == SizeExpr::Kind::Max)
-        dk = AdStackSizeExprDeviceKind::kMax;
-      AdStackSizeExprDeviceNode dn = make_empty_device_node(static_cast<int32_t>(dk));
-      dn.operand_a = a;
-      dn.operand_b = b;
-      out_nodes.push_back(dn);
-      return static_cast<int32_t>(out_nodes.size() - 1);
-    }
-    case SizeExpr::Kind::MaxOverRange: {
-      int32_t a = encode_subtree(src, node.operand_a, contains_device_leaf, free_vars, var_id_remap, prog, ctx,
-                                 fl_emitter, out_nodes, out_indices, reads);
-      int32_t b = encode_subtree(src, node.operand_b, contains_device_leaf, free_vars, var_id_remap, prog, ctx,
-                                 fl_emitter, out_nodes, out_indices, reads);
-      int32_t body = encode_subtree(src, node.body_node_idx, contains_device_leaf, free_vars, var_id_remap, prog, ctx,
-                                    fl_emitter, out_nodes, out_indices, reads);
-      AdStackSizeExprDeviceNode dn =
-          make_empty_device_node(static_cast<int32_t>(AdStackSizeExprDeviceKind::kMaxOverRange));
-      dn.operand_a = a;
-      dn.operand_b = b;
-      dn.body_node_idx = body;
-      dn.var_id = remap_var_id(var_id_remap, node.var_id);
-      out_nodes.push_back(dn);
-      return static_cast<int32_t>(out_nodes.size() - 1);
-    }
-    case SizeExpr::Kind::ExternalTensorRead: {
-      QD_ASSERT_INFO(ctx != nullptr && ctx->args_type != nullptr,
-                     "encode_subtree: ExternalTensorRead at node {} requires a LaunchContextBuilder with a valid "
-                     "args_type to precompute the data_ptr offset",
-                     src_idx);
-      QD_ASSERT_INFO(!node.arg_id_path.empty(), "ExternalTensorRead at node {} has empty arg_id_path", src_idx);
-      std::vector<int> arg_indices(node.arg_id_path.begin(), node.arg_id_path.end());
-      arg_indices.push_back(TypeFactory::DATA_PTR_POS_IN_NDARRAY);
-      const size_t data_ptr_offset = ctx->args_type->get_element_offset(arg_indices);
-      AdStackSizeExprDeviceNode dn =
-          make_empty_device_node(static_cast<int32_t>(AdStackSizeExprDeviceKind::kExternalTensorRead));
-      // Cast to i32 is safe: `arg_buffer` sizes in practice are kilobytes, well under INT32_MAX.
-      dn.arg_buffer_offset = static_cast<int32_t>(data_ptr_offset);
-      dn.prim_dt = static_cast<int32_t>(node.const_value);  // the pre-pass stashes `PrimitiveTypeID` in const_value
-      dn.indices_offset = static_cast<int32_t>(out_indices.size());
-      dn.indices_count = static_cast<int32_t>(node.indices.size());
-      // Pre-compute per-axis element strides in C order (`stride[k] = prod_{m > k} shape[m]`). Shapes live in
-      // the kernel args struct as `int32` slots at the `SHAPE_POS_IN_NDARRAY` path, same source the host
-      // evaluator reads; using the live launch context keeps strides consistent with whichever ndarray the
-      // user handed to the kernel on this launch. Emit as `[idx_a_raw, elem_stride_a]` pairs per axis,
-      // matching the `kFieldLoad` layout so the device interpreter and SPIR-V sizer shader can share one
-      // pair-walking offset-computation loop instead of carrying a separate stride-1 path.
-      std::vector<int32_t> elem_strides(node.indices.size(), 1);
-      if (node.indices.size() > 1) {
-        for (std::size_t k = node.indices.size(); k-- > 0;) {
-          if (k + 1 < node.indices.size()) {
-            std::vector<int> sh_idx(node.arg_id_path.begin(), node.arg_id_path.end());
-            sh_idx.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY);
-            sh_idx.push_back(static_cast<int>(k + 1));
-            int32_t sh = ctx->get_struct_arg_host<int32_t>(sh_idx);
-            elem_strides[k] = elem_strides[k + 1] * sh;
-          }
-        }
-      }
-      for (std::size_t k = 0; k < node.indices.size(); ++k) {
-        int32_t raw = node.indices[k];
-        if (raw < 0) {
-          // Remap bound-variable refs so the device interpreter's `scope->values[var]` read lands in the
-          // `[0, kAdStackSizeExprDeviceMaxBoundVars)` range regardless of how large the source tree's
-          // `var_id_counter` grew across its push-site walks.
-          int32_t dense = remap_var_id(var_id_remap, -(raw + 1));
-          raw = -(dense + 1);
-        }
-        out_indices.push_back(raw);
-        out_indices.push_back(elem_strides[k]);
-      }
-      out_nodes.push_back(dn);
-      return static_cast<int32_t>(out_nodes.size() - 1);
-    }
-    case SizeExpr::Kind::FieldLoad: {
-      // If we reach here the subtree is not host-substitutable (has free bound vars or sits alongside an
-      // `ExternalTensorRead` in the same closed context, or - on the SPIR-V path - `FieldLoad` is deliberately
-      // kept on the device via `fl_emitter`). Without an emitter, the LLVM path would have folded it earlier;
-      // reaching here without one means the shape is outside what the grammar supports, which is a user-facing
-      // bug, not a runtime fallback.
-      QD_ASSERT_INFO(
-          !fl_emitter.empty(),
-          "Adstack SizeExpr FieldLoad at node {} survived the host fold without a FieldLoadDeviceEmitter. The "
-          "LLVM encoder should route closed FieldLoads through `evaluate_node` and reject non-closed ones before "
-          "the structural pre-pass emits them; if this fires, a SerializedSizeExpr with a bound-var-indexed "
-          "FieldLoad leaf reached an LLVM-targeted encoder (which cannot resolve it on-device).",
-          src_idx);
-      QD_ASSERT_INFO(node.snode_id >= 0, "FieldLoad at node {} has no snode_id", src_idx);
-      QD_ASSERT_INFO(prog != nullptr, "encode_subtree: FieldLoad needs a live Program to resolve snode {}",
-                     node.snode_id);
-      SNode *snode = prog->get_snode_by_id(node.snode_id);
-      QD_ASSERT_INFO(snode != nullptr,
-                     "FieldLoad at node {} references snode_id={} which is not in the program's snode tree", src_idx,
-                     node.snode_id);
-      uint64_t base_psb = 0;
-      std::vector<int32_t> elem_strides;
-      bool fetched = fl_emitter.fetch(snode, &base_psb, &elem_strides);
-      QD_ERROR_IF(!fetched,
-                  "Adstack SizeExpr FieldLoad at node {} on snode_id={} could not be resolved for device-side "
-                  "evaluation: the snode layout is not a pure-dense chain ending in a plain place leaf (bitmasked "
-                  "/ pointer / bit-level snodes are not supported by the SPIR-V sizer shader). Rewrite the trip "
-                  "count to use a dense field, or extend the shader to walk the non-dense hierarchy.",
-                  src_idx, node.snode_id);
-      QD_ASSERT_INFO(elem_strides.size() == node.indices.size(),
-                     "FieldLoad at node {}: elem_strides.size()={} must match node.indices.size()={} (one stride "
-                     "per active axis)",
-                     src_idx, elem_strides.size(), node.indices.size());
-      AdStackSizeExprDeviceNode dn =
-          make_empty_device_node(static_cast<int32_t>(AdStackSizeExprDeviceKind::kFieldLoad));
-      dn.const_value = static_cast<int64_t>(base_psb);
-      // `PrimitiveTypeID` for the leaf: mirrors ExternalTensorRead's field. The pre-pass emits a `FieldLoad`
-      // `SerializedSizeExprNode` with `snode_id` set and the element type implicit in the snode; we look it up
-      // here so the shader's existing `emit_psb_load_i64` switch (shared with ETR) can dispatch on it.
-      dn.prim_dt = static_cast<int32_t>(snode->dt->cast<PrimitiveType>()->type);
-      dn.indices_offset = static_cast<int32_t>(out_indices.size());
-      dn.indices_count = static_cast<int32_t>(node.indices.size());
-      // Interleaved `[idx_a_raw, elem_stride_a]` pairs per axis. The shader reads 2 i32s per axis and
-      // accumulates `idx_a * stride_a` into the element index, then `psb_load_scalar` multiplies by the
-      // element size to get the final byte offset. Bound-variable refs (negative entries) are dense-remapped
-      // so the device interpreter's fixed-size `scope->values[]` stays in bounds.
-      for (std::size_t a = 0; a < node.indices.size(); ++a) {
-        int32_t raw = static_cast<int32_t>(node.indices[a]);
-        if (raw < 0) {
-          int32_t dense = remap_var_id(var_id_remap, -(raw + 1));
-          raw = -(dense + 1);
-        }
-        out_indices.push_back(raw);
-        out_indices.push_back(elem_strides[a]);
-      }
-      out_nodes.push_back(dn);
-      return static_cast<int32_t>(out_nodes.size() - 1);
-    }
-    case SizeExpr::Kind::ExternalTensorShape: {
-      // Should have been folded to `Const` by the `subtree_needs_device == false && subtree_closed == true`
-      // branch above, since `ExternalTensorShape` has no free vars and cannot be an `ExternalTensorRead`. Hitting
-      // this branch is a bug in the encoder, not in the kernel.
-      QD_ERROR(
-          "Adstack SizeExpr ExternalTensorShape at node {} escaped host folding: this is an encoder invariant"
-          " violation - shape nodes are always closed and should have been emitted as Const.",
-          src_idx);
-      return -1;
-    }
-  }
-  QD_ERROR("encode_subtree: unreachable kind {} at node {}", node.kind, src_idx);
-  return -1;
-}
-
-}  // namespace
-
-namespace {
-
-// Per-launch cache of `FieldLoad` re-reads, keyed by `(snode_id, indices)`. Within one host-side eval root
-// call the SNode field values are pinned (no other kernel runs concurrently), so deduping repeats across
-// the size-expr trees evaluated in that window is correctness-safe.
-struct LaunchScopedReadCache {
-  struct Key {
-    int snode_id;
-    std::vector<int> indices;
-    bool operator==(const Key &o) const noexcept {
-      return snode_id == o.snode_id && indices == o.indices;
-    }
-  };
-  struct KeyHash {
-    std::size_t operator()(const Key &k) const noexcept {
-      std::size_t h = std::hash<int>{}(k.snode_id);
-      for (int v : k.indices) {
-        h ^= std::hash<int>{}(v) + 0x9e3779b97f4a7c15ull + (h << 6) + (h >> 2);
-      }
-      return h;
-    }
-  };
-  std::unordered_map<Key, int64_t, KeyHash> map;
-};
-thread_local LaunchScopedReadCache *t_launch_read_cache = nullptr;
-
-int64_t read_field_with_launch_cache(int snode_id, const std::vector<int> &indices, Program *prog) {
-  SNode *snode = prog->get_snode_by_id(snode_id);
-  if (snode == nullptr) {
-    return std::numeric_limits<int64_t>::min();
-  }
-  if (t_launch_read_cache != nullptr) {
-    LaunchScopedReadCache::Key key{snode_id, indices};
-    auto it = t_launch_read_cache->map.find(key);
-    if (it != t_launch_read_cache->map.end()) {
-      return it->second;
-    }
-    int64_t v = prog->get_snode_rw_accessors_bank().get(snode).read_int(indices);
-    t_launch_read_cache->map.emplace(std::move(key), v);
-    return v;
-  }
-  return prog->get_snode_rw_accessors_bank().get(snode).read_int(indices);
-}
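The dedupe contract in isolation (a sketch: `read_fn` stands in for the accessor-bank dispatch, an ordered `std::map` for the hashed map above):

    #include <cstdint>
    #include <map>
    #include <utility>
    #include <vector>

    using ReadKey = std::pair<int, std::vector<int>>;  // (snode_id, indices)

    template <typename ReadFn>
    int64_t cached_read(std::map<ReadKey, int64_t> &cache, int id,
                        const std::vector<int> &indices, ReadFn read_fn) {
      auto [it, inserted] = cache.try_emplace(ReadKey{id, indices}, 0);
      if (inserted) {
        it->second = read_fn(id, indices);  // first occurrence: dispatch the real read once
      }
      return it->second;  // repeats within the scope: served from the map
    }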
-// Read the input that `obs` describes against the live state and `ctx`. Caller compares the result to
-// `obs.observed_value` to decide whether the cached `SizeExprCacheEntry` is still valid. Each `obs.kind`
-// mirrors the corresponding leaf in `evaluate_field_load` / `evaluate_external_tensor_shape` /
-// `evaluate_external_tensor_read`.
-int64_t replay_one_observation(const AdStackCache::SizeExprReadObservation &obs,
-                               Program *prog,
-                               LaunchContextBuilder *ctx) {
-  using Obs = AdStackCache::SizeExprReadObservation;
-  switch (obs.kind) {
-    case Obs::FieldLoadObs: {
-      // Gen-counter fast skip: when no kernel has bumped this SNode's write generation since record time,
-      // the underlying field value cannot have changed and we can return the recorded `observed_value`
-      // without dispatching a reader kernel. The dispatch is the dominant per-launch cost on the hot path
-      // for steady-state reverse-mode loops with stable bounds.
-      if (prog != nullptr && prog->adstack_cache().snode_write_gen(obs.snode_id) == obs.observed_gen) {
-        return obs.observed_value;
-      }
-      int64_t v = read_field_with_launch_cache(obs.snode_id, obs.indices, prog);
-      if (v == std::numeric_limits<int64_t>::min()) {
-        return obs.observed_value + 1;  // force a mismatch if SNode disappeared
-      }
-      return v;
-    }
-    case Obs::ExternalShapeObs: {
-      if (ctx == nullptr) {
-        return obs.observed_value + 1;
-      }
-      std::vector<int> arg_indices(obs.arg_id_path.begin(), obs.arg_id_path.end());
-      arg_indices.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY);
-      arg_indices.push_back(obs.arg_shape_axis);
-      return static_cast<int64_t>(ctx->get_struct_arg_host<int32_t>(arg_indices));
-    }
-    case Obs::ExternalReadObs: {
-      if (ctx == nullptr || obs.arg_id_path.empty()) {
-        return obs.observed_value + 1;
-      }
-      int arg_id = obs.arg_id_path[0];
-      ArgArrayPtrKey key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
-      auto it = ctx->array_ptrs.find(key);
-      if (it == ctx->array_ptrs.end()) {
-        return obs.observed_value + 1;
-      }
-      void *data_ptr = it->second;
-      // Gen-counter fast skip: when the data pointer is the same `DeviceAllocation *` we observed at record
-      // time AND its data generation has not been bumped since (no kernel write, no host-side `Ndarray.write`
-      // / `fill`), the underlying scalar cannot have changed and we can return the recorded value without
-      // dereferencing the device pointer (which on GPU would be a DtoH copy, on CPU a host load).
-      if (prog != nullptr && data_ptr == obs.observed_devalloc &&
-          prog->adstack_cache().ndarray_data_gen(data_ptr) == obs.observed_gen) {
-        return obs.observed_value;
-      }
-      int64_t linear = 0;
-      int64_t stride = 1;
-      for (std::size_t i = obs.indices.size(); i > 0; --i) {
-        linear += static_cast<int64_t>(obs.indices[i - 1]) * stride;
-        if (i - 1 > 0) {
-          std::vector<int> sh_idx(obs.arg_id_path.begin(), obs.arg_id_path.end());
-          sh_idx.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY);
-          sh_idx.push_back(static_cast<int>(i - 1));
-          stride *= static_cast<int64_t>(ctx->get_struct_arg_host<int32_t>(sh_idx));
-        }
-      }
-      switch (static_cast<PrimitiveTypeID>(obs.prim_dt)) {
-        case PrimitiveTypeID::i32:
-          return static_cast<int64_t>(static_cast<const int32_t *>(data_ptr)[linear]);
-        case PrimitiveTypeID::i64:
-          return static_cast<const int64_t *>(data_ptr)[linear];
-        case PrimitiveTypeID::u32:
-          return static_cast<int64_t>(static_cast<const uint32_t *>(data_ptr)[linear]);
-        case PrimitiveTypeID::u64:
-          return static_cast<int64_t>(static_cast<const uint64_t *>(data_ptr)[linear]);
-        case PrimitiveTypeID::i16:
-          return static_cast<int64_t>(static_cast<const int16_t *>(data_ptr)[linear]);
-        case PrimitiveTypeID::u16:
-          return static_cast<int64_t>(static_cast<const uint16_t *>(data_ptr)[linear]);
-        case PrimitiveTypeID::i8:
-          return static_cast<int64_t>(static_cast<const int8_t *>(data_ptr)[linear]);
-        case PrimitiveTypeID::u8:
-          return static_cast<int64_t>(static_cast<const uint8_t *>(data_ptr)[linear]);
-        default:
-          return obs.observed_value + 1;
-      }
-    }
-  }
-  return obs.observed_value + 1;
-}
-}  // namespace
-
-bool AdStackCache::try_size_expr_cache_hit(Program *prog,
-                                           const SerializedSizeExpr *expr_key,
-                                           LaunchContextBuilder *ctx,
-                                           int64_t &out_result) {
-  auto it = size_expr_cache_.find(expr_key);
-  if (it == size_expr_cache_.end()) {
-    return false;
-  }
-  const auto &entry = it->second;
-  for (const auto &obs : entry.reads) {
-    int64_t now = replay_one_observation(obs, prog, ctx);
-    if (now != obs.observed_value) {
-      size_expr_cache_.erase(it);
-      return false;
-    }
-  }
-  out_result = entry.result;
-  return true;
-}
-
-void AdStackCache::record_size_expr_eval(const SerializedSizeExpr *expr_key,
-                                         int64_t result,
-                                         std::vector<SizeExprReadObservation> reads) {
-  size_expr_cache_[expr_key] = SizeExprCacheEntry{result, std::move(reads)};
-}
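The record/replay idea generalizes to any cache whose result is a pure function of a few observable inputs; a minimal sketch (`Observation` and `read_now` are illustrative stand-ins for `SizeExprReadObservation` / `replay_one_observation`):

    #include <cstdint>
    #include <vector>

    struct Observation {
      int input_id;            // which input the recorded run read
      int64_t observed_value;  // the value it saw
    };

    // Reusable iff re-reading every recorded input yields the recorded value;
    // the caller evicts the entry on the first drift.
    template <typename ReadNow>
    bool replay_matches(const std::vector<Observation> &obs, ReadNow read_now) {
      for (const auto &o : obs) {
        if (read_now(o.input_id) != o.observed_value) return false;
      }
      return true;
    }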
-bool AdStackCache::try_spirv_bytecode_cache_hit(Program *prog,
-                                                const void *attribs_key,
-                                                LaunchContextBuilder *ctx,
-                                                std::vector<uint32_t> &out_bytecode) {
-  auto it = spirv_bytecode_cache_.find(attribs_key);
-  if (it == spirv_bytecode_cache_.end()) {
-    return false;
-  }
-  const auto &entry = it->second;
-  for (const auto &obs : entry.reads) {
-    int64_t now = replay_one_observation(obs, prog, ctx);
-    if (now != obs.observed_value) {
-      spirv_bytecode_cache_.erase(it);
-      return false;
-    }
-  }
-  out_bytecode = entry.bytecode;
-  return true;
-}
-
-void AdStackCache::record_spirv_bytecode_eval(const void *attribs_key,
-                                              std::vector<uint32_t> bytecode,
-                                              std::vector<SizeExprReadObservation> reads) {
-  spirv_bytecode_cache_[attribs_key] = SpirvBytecodeCacheEntry{std::move(bytecode), std::move(reads)};
-}
-
-void AdStackCache::record_per_task_ad_stack(const void *attribs_key,
-                                            std::vector<uint32_t> metadata,
-                                            uint32_t stride_float,
-                                            uint32_t stride_int,
-                                            std::vector<std::pair<int, uint64_t>> snode_gens,
-                                            std::vector<std::tuple<int, void *, uint64_t>> arg_gens) {
-  per_task_ad_stack_cache_[attribs_key] = PerTaskAdStackCacheEntry{std::move(metadata), stride_float, stride_int,
-                                                                   std::move(snode_gens), std::move(arg_gens)};
-}
-
-bool AdStackCache::try_per_task_ad_stack_cache_hit(const void *attribs_key,
-                                                   LaunchContextBuilder *ctx,
-                                                   PerTaskAdStackCacheEntry &out) {
-  auto it = per_task_ad_stack_cache_.find(attribs_key);
-  if (it == per_task_ad_stack_cache_.end()) {
-    return false;
-  }
-  const auto &entry = it->second;
-  for (const auto &snode_pair : entry.snode_gens) {
-    if (snode_write_gen(snode_pair.first) != snode_pair.second) {
-      per_task_ad_stack_cache_.erase(it);
-      return false;
-    }
-  }
-  for (const auto &arg_tuple : entry.arg_gens) {
-    int arg_id = std::get<0>(arg_tuple);
-    void *recorded_devalloc = std::get<1>(arg_tuple);
-    uint64_t recorded_gen = std::get<2>(arg_tuple);
-    void *current_devalloc = nullptr;
-    if (ctx != nullptr) {
-      ArgArrayPtrKey key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
-      auto ap_it = ctx->array_ptrs.find(key);
-      if (ap_it != ctx->array_ptrs.end()) {
-        current_devalloc = ap_it->second;
-      }
-    }
-    if (current_devalloc != recorded_devalloc) {
-      per_task_ad_stack_cache_.erase(it);
-      return false;
-    }
-    if (ndarray_data_gen(recorded_devalloc) != recorded_gen) {
-      per_task_ad_stack_cache_.erase(it);
-      return false;
-    }
-  }
-  out = entry;
-  return true;
-}
-
-void AdStackCache::record_llvm_per_task_ad_stack(const void *attribs_key,
-                                                 std::vector<uint64_t> offsets,
-                                                 std::vector<uint64_t> max_sizes,
-                                                 uint64_t stride_combined,
-                                                 uint64_t stride_float,
-                                                 uint64_t stride_int,
-                                                 std::vector<std::pair<int, uint64_t>> snode_gens,
-                                                 std::vector<std::tuple<int, void *, uint64_t>> arg_gens) {
-  llvm_per_task_ad_stack_cache_[attribs_key] =
-      LlvmPerTaskAdStackCacheEntry{std::move(offsets), std::move(max_sizes), stride_combined, stride_float,
-                                   stride_int, std::move(snode_gens), std::move(arg_gens)};
-}
-
-bool AdStackCache::try_llvm_per_task_ad_stack_cache_hit(const void *attribs_key,
-                                                        LaunchContextBuilder *ctx,
-                                                        LlvmPerTaskAdStackCacheEntry &out) {
-  auto it = llvm_per_task_ad_stack_cache_.find(attribs_key);
-  if (it == llvm_per_task_ad_stack_cache_.end()) {
-    return false;
-  }
-  const auto &entry = it->second;
-  for (const auto &snode_pair : entry.snode_gens) {
-    if (snode_write_gen(snode_pair.first) != snode_pair.second) {
-      llvm_per_task_ad_stack_cache_.erase(it);
-      return false;
-    }
-  }
-  for (const auto &arg_tuple : entry.arg_gens) {
-    int arg_id = std::get<0>(arg_tuple);
-    void *recorded_devalloc = std::get<1>(arg_tuple);
-    uint64_t recorded_gen = std::get<2>(arg_tuple);
-    void *current_devalloc = nullptr;
-    if (ctx != nullptr) {
-      ArgArrayPtrKey key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
-      auto ap_it = ctx->array_ptrs.find(key);
-      if (ap_it != ctx->array_ptrs.end()) {
-        current_devalloc = ap_it->second;
-      }
-    }
-    if (current_devalloc != recorded_devalloc) {
-      llvm_per_task_ad_stack_cache_.erase(it);
-      return false;
-    }
-    if (ndarray_data_gen(recorded_devalloc) != recorded_gen) {
-      llvm_per_task_ad_stack_cache_.erase(it);
-      return false;
-    }
-  }
-  out = entry;
-  return true;
-}
-
-// Per-thread backing for `SizeExprLaunchScope`. The outer scope on each thread points `t_launch_read_cache` here
-// after clearing the map; nested scopes are no-ops.
-thread_local LaunchScopedReadCache t_launch_read_cache_storage{};
-
-SizeExprLaunchScope::SizeExprLaunchScope() : owns_(t_launch_read_cache == nullptr) {
-  if (owns_) {
-    t_launch_read_cache_storage.map.clear();
-    t_launch_read_cache = &t_launch_read_cache_storage;
-  }
-}
-SizeExprLaunchScope::~SizeExprLaunchScope() {
-  if (owns_) {
-    t_launch_read_cache = nullptr;
-  }
-}
-
-int64_t evaluate_adstack_size_expr(const SerializedSizeExpr &expr, Program *prog, LaunchContextBuilder *ctx) {
-  if (expr.nodes.empty()) {
-    return -1;
-  }
-  // Open a `SizeExprLaunchScope` if no enclosing one is active, so repeated reads within this eval share
-  // the launch read cache. Callers that issue several `evaluate_adstack_size_expr` calls back-to-back
-  // should open their own scope to span all of them.
-  SizeExprLaunchScope local_scope;
-
-  // Cache fast path: replay the recorded reads against the live state and reuse the cached result if
-  // every input still matches. The full walk runs only on cache miss.
-  if (prog != nullptr) {
-    int64_t cached;
-    if (prog->adstack_cache().try_size_expr_cache_hit(prog, &expr, ctx, cached)) {
-      return cached;
-    }
-  }
-  std::unordered_map<int32_t, int64_t> empty_bound_vars;
-  std::vector<AdStackCache::SizeExprReadObservation> reads;
-  int64_t result =
-      evaluate_node(expr, static_cast<int32_t>(expr.nodes.size() - 1), empty_bound_vars, prog, ctx, &reads);
-  if (prog != nullptr) {
-    prog->adstack_cache().record_size_expr_eval(&expr, result, std::move(reads));
-  }
-  return result;
-}
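Caller-side, the scope comment above implies this batching pattern; an illustrative sketch (the `evaluate_batch` wrapper is hypothetical, the called names are the real ones from this file):

    // Hypothetical batch walk: one scope spans every eval so repeated FieldLoad
    // reads across the trees dispatch a reader kernel at most once per
    // (snode_id, indices) pair; the scope each eval opens internally is a no-op.
    void evaluate_batch(const std::vector<SerializedSizeExpr> &exprs,
                        Program *prog, LaunchContextBuilder *ctx) {
      SizeExprLaunchScope scope;
      for (const SerializedSizeExpr &e : exprs) {
        const int64_t upper_bound = evaluate_adstack_size_expr(e, prog, ctx);
        (void)upper_bound;  // consumed by the launcher in the real flow
      }
    }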
-namespace {
-
-// Diagnose-time leaf reader: resolves an `ExternalTensorRead` against the captured
-// `AdStackCache::DiagnoseLaunchSnapshot` and the program's `Device::map` interface. Returns -1 on any failure
-// (missing arg in snapshot, unrecognised primitive type, mapping failure) so the caller can substitute the
-// `?` placeholder for that stack while keeping the rest of the message intact.
-//
-// Single-scalar staging-buffer pattern (mirrors `Ndarray::read` in `program/ndarray.cpp`): allocate a tiny
-// `host_read=true` staging buffer, `memcpy_internal` the one element from the ndarray's device buffer into
-// staging, then map staging to read the value host-side. This works on every backend because every
-// `Device` implementation supports `host_read=true` allocations + `map` + `memcpy_internal`. For `kNone`
-// numpy passthrough the captured pointer is already host-readable; we read it directly.
-int64_t read_diagnose_external_tensor(const SerializedSizeExprNode &node,
-                                      const std::vector<int64_t> &resolved_indices,
-                                      Program *prog,
-                                      const AdStackCache::DiagnoseLaunchSnapshot &snapshot) {
-  if (node.arg_id_path.empty()) {
-    return -1;
-  }
-  int arg_id = node.arg_id_path[0];
-  auto ptr_it = snapshot.data_ptrs.find(arg_id);
-  if (ptr_it == snapshot.data_ptrs.end() || ptr_it->second == nullptr) {
-    return -1;
-  }
-  auto type_it = snapshot.dev_alloc_types.find(arg_id);
-  if (type_it == snapshot.dev_alloc_types.end()) {
-    return -1;
-  }
-  auto shape_it = snapshot.shapes.find(arg_id);
-  if (shape_it == snapshot.shapes.end()) {
-    return -1;
-  }
-  const std::vector<int> &shape = shape_it->second;
-  // Compose C-order linear offset across resolved indices (mirrors `evaluate_external_tensor_read`'s stride
-  // math; we cannot share the helper because that one routes through `LaunchContextBuilder::get_struct_arg_host`
-  // which is not available here).
-  if (resolved_indices.size() > shape.size() && !shape.empty()) {
-    // More indices than rank - the size_expr was lowered against a different shape; skip.
-    return -1;
-  }
-  int64_t linear = 0;
-  int64_t stride = 1;
-  for (std::size_t i = resolved_indices.size(); i > 0; --i) {
-    linear += resolved_indices[i - 1] * stride;
-    if (i - 1 > 0 && i - 1 < shape.size()) {
-      stride *= static_cast<int64_t>(shape[i - 1]);
-    }
-  }
-  auto prim_dt = static_cast<PrimitiveTypeID>(node.const_value);
-  std::size_t elem_size = 0;
-  switch (prim_dt) {
-    case PrimitiveTypeID::i8:
-    case PrimitiveTypeID::u8:
-      elem_size = 1;
-      break;
-    case PrimitiveTypeID::i16:
-    case PrimitiveTypeID::u16:
-      elem_size = 2;
-      break;
-    case PrimitiveTypeID::i32:
-    case PrimitiveTypeID::u32:
-      elem_size = 4;
-      break;
-    case PrimitiveTypeID::i64:
-    case PrimitiveTypeID::u64:
-      elem_size = 8;
-      break;
-    default:
-      return -1;
-  }
-  std::size_t byte_offset = static_cast<std::size_t>(linear) * elem_size;
-  // Decode the captured pointer to host bytes.
-  std::vector<uint8_t> staging_bytes(elem_size);
-  if (type_it->second == LaunchContextBuilder::DevAllocType::kNone) {
-    // Numpy passthrough: ptr is already a raw host pointer.
-    const uint8_t *src = static_cast<const uint8_t *>(ptr_it->second) + byte_offset;
-    std::memcpy(staging_bytes.data(), src, elem_size);
-  } else if (type_it->second == LaunchContextBuilder::DevAllocType::kNdarray) {
-    if (prog == nullptr) {
-      return -1;
-    }
-    auto *alloc = static_cast<DeviceAllocation *>(ptr_it->second);
-    if (alloc == nullptr || alloc->device == nullptr) {
-      return -1;
-    }
-    Device::AllocParams params;
-    params.host_write = false;
-    params.host_read = true;
-    params.size = elem_size;
-    params.usage = AllocUsage::Storage;
-    auto [staging, alloc_res] = alloc->device->allocate_memory_unique(params);
-    if (alloc_res != RhiResult::success || !staging) {
-      return -1;
-    }
-    alloc->device->memcpy_internal(staging->get_ptr(), alloc->get_ptr(byte_offset), elem_size);
-    void *mapped = nullptr;
-    if (alloc->device->map(*staging, &mapped) != RhiResult::success || mapped == nullptr) {
-      return -1;
-    }
-    std::memcpy(staging_bytes.data(), mapped, elem_size);
-    alloc->device->unmap(*staging);
-  } else {
-    return -1;
-  }
-  // Sign- / zero-extend to int64 according to the captured primitive type.
-  switch (prim_dt) {
-    case PrimitiveTypeID::i8:
-      return static_cast<int64_t>(*reinterpret_cast<const int8_t *>(staging_bytes.data()));
-    case PrimitiveTypeID::u8:
-      return static_cast<int64_t>(*reinterpret_cast<const uint8_t *>(staging_bytes.data()));
-    case PrimitiveTypeID::i16:
-      return static_cast<int64_t>(*reinterpret_cast<const int16_t *>(staging_bytes.data()));
-    case PrimitiveTypeID::u16:
-      return static_cast<int64_t>(*reinterpret_cast<const uint16_t *>(staging_bytes.data()));
-    case PrimitiveTypeID::i32:
-      return static_cast<int64_t>(*reinterpret_cast<const int32_t *>(staging_bytes.data()));
-    case PrimitiveTypeID::u32:
-      return static_cast<int64_t>(*reinterpret_cast<const uint32_t *>(staging_bytes.data()));
-    case PrimitiveTypeID::i64:
-      return *reinterpret_cast<const int64_t *>(staging_bytes.data());
-    case PrimitiveTypeID::u64:
-      return static_cast<int64_t>(*reinterpret_cast<const uint64_t *>(staging_bytes.data()));
-    default:
-      return -1;
-  }
-}
-
-// Mirror of `evaluate_node` for diagnose-time evaluation. Same tree-walk semantics; differs only in the leaf
-// case for `ExternalTensorRead` / `ExternalTensorShape`, which route through the snapshot + `Device::map` path
-// instead of `LaunchContextBuilder`. Returns -1 on any leaf-resolution failure to short-circuit the rest of
-// the walk and let the caller fall back to the static dual-cause body.
-int64_t evaluate_node_for_diagnose(const SerializedSizeExpr &expr,
-                                   int32_t node_idx,
-                                   std::unordered_map<int32_t, int64_t> &bound_vars,
-                                   Program *prog,
-                                   const AdStackCache::DiagnoseLaunchSnapshot &snapshot) {
-  if (node_idx < 0 || static_cast<std::size_t>(node_idx) >= expr.nodes.size()) {
-    return -1;
-  }
-  const auto &node = expr.nodes[node_idx];
-  switch (static_cast<SizeExpr::Kind>(node.kind)) {
-    case SizeExpr::Kind::Const:
-      return node.const_value;
-    case SizeExpr::Kind::FieldLoad: {
-      // Field reads stay on the existing host path - they do not depend on `LaunchContextBuilder` and the
-      // SNode reader-kernel dispatch is host-driven. We pass `nullptr` ReadSink so the recorded observations
-      // do not leak into the cache from a diagnose-only walk.
-      return evaluate_field_load(node, bound_vars, prog, /*reads=*/nullptr);
-    }
-    case SizeExpr::Kind::Add: {
-      int64_t a = evaluate_node_for_diagnose(expr, node.operand_a, bound_vars, prog, snapshot);
-      int64_t b = evaluate_node_for_diagnose(expr, node.operand_b, bound_vars, prog, snapshot);
-      if (a < 0 || b < 0) {
-        return -1;
-      }
-      return a + b;
-    }
-    case SizeExpr::Kind::Sub: {
-      int64_t a = evaluate_node_for_diagnose(expr, node.operand_a, bound_vars, prog, snapshot);
-      int64_t b = evaluate_node_for_diagnose(expr, node.operand_b, bound_vars, prog, snapshot);
-      if (a < 0 || b < 0) {
-        return -1;
-      }
-      return std::max<int64_t>(a - b, 0);
-    }
-    case SizeExpr::Kind::Mul: {
-      int64_t a = evaluate_node_for_diagnose(expr, node.operand_a, bound_vars, prog, snapshot);
-      int64_t b = evaluate_node_for_diagnose(expr, node.operand_b, bound_vars, prog, snapshot);
-      if (a < 0 || b < 0) {
-        return -1;
-      }
-      return a * b;
-    }
-    case SizeExpr::Kind::Max: {
-      int64_t a = evaluate_node_for_diagnose(expr, node.operand_a, bound_vars, prog, snapshot);
-      int64_t b = evaluate_node_for_diagnose(expr, node.operand_b, bound_vars, prog, snapshot);
-      if (a < 0 || b < 0) {
-        return -1;
-      }
-      return std::max(a, b);
-    }
-    case SizeExpr::Kind::MaxOverRange: {
-      int64_t begin = evaluate_node_for_diagnose(expr, node.operand_a, bound_vars, prog, snapshot);
-      int64_t end = evaluate_node_for_diagnose(expr, node.operand_b, bound_vars, prog, snapshot);
-      if (begin < 0 || end < 0) {
-        return -1;
-      }
-      // Same iteration cap as the live evaluator; refusing to enumerate prevents diagnose from stalling
-      // the error path on a pathological trip count.
-      constexpr int64_t kMaxOverRangeIterations = int64_t{1} << 24;
-      if (end > begin && end - begin > kMaxOverRangeIterations) {
-        return -1;
-      }
-      int64_t result = 0;
-      auto prev_it = bound_vars.find(node.var_id);
-      bool had_prev = prev_it != bound_vars.end();
-      int64_t prev_val = had_prev ? prev_it->second : 0;
-      for (int64_t i = begin; i < end; ++i) {
-        bound_vars[node.var_id] = i;
-        int64_t v = evaluate_node_for_diagnose(expr, node.body_node_idx, bound_vars, prog, snapshot);
-        if (v < 0) {
-          if (had_prev) {
-            bound_vars[node.var_id] = prev_val;
-          } else {
-            bound_vars.erase(node.var_id);
-          }
-          return -1;
-        }
-        if (v > result) {
-          result = v;
-        }
-      }
-      if (had_prev) {
-        bound_vars[node.var_id] = prev_val;
-      } else {
-        bound_vars.erase(node.var_id);
-      }
-      return result;
-    }
-    case SizeExpr::Kind::BoundVariable: {
-      auto it = bound_vars.find(node.var_id);
-      if (it == bound_vars.end()) {
-        return -1;
-      }
-      return it->second;
-    }
-    case SizeExpr::Kind::ExternalTensorShape: {
-      if (node.arg_id_path.empty()) {
-        return -1;
-      }
-      int arg_id = node.arg_id_path[0];
-      auto shape_it = snapshot.shapes.find(arg_id);
-      if (shape_it == snapshot.shapes.end()) {
-        return -1;
-      }
-      if (node.arg_shape_axis < 0 || static_cast<std::size_t>(node.arg_shape_axis) >= shape_it->second.size()) {
-        return -1;
-      }
-      return static_cast<int64_t>(shape_it->second[node.arg_shape_axis]);
-    }
-    case SizeExpr::Kind::ExternalTensorRead: {
-      // Resolve indices from bound_vars first, then dispatch to the snapshot-aware reader.
-      std::vector<int64_t> resolved(node.indices.size());
-      for (std::size_t i = 0; i < node.indices.size(); ++i) {
-        int32_t raw = node.indices[i];
-        if (raw >= 0) {
-          resolved[i] = raw;
-        } else {
-          int32_t var_id = -(raw + 1);
-          auto bv = bound_vars.find(var_id);
-          if (bv == bound_vars.end()) {
-            return -1;
-          }
-          resolved[i] = bv->second;
-        }
-      }
-      return read_diagnose_external_tensor(node, resolved, prog, snapshot);
-    }
-  }
-  return -1;
-}
-
-} // namespace
-
-int64_t evaluate_adstack_size_expr_for_diagnose(const SerializedSizeExpr &expr, Program *prog) {
-  if (expr.nodes.empty() || prog == nullptr) {
-    return -1;
-  }
-  const AdStackCache::DiagnoseLaunchSnapshot *snapshot = prog->adstack_cache().get_diagnose_snapshot();
-  if (snapshot == nullptr) {
-    return -1;
-  }
-  std::unordered_map<int32_t, int64_t> bound_vars;
-  return evaluate_node_for_diagnose(expr, static_cast<int32_t>(expr.nodes.size() - 1), bound_vars, prog, *snapshot);
-}
-
-uint32_t AdStackCache::register_adstack_sizing_info(const void *identity_key,
-                                                    const std::string &kernel_name,
-                                                    int task_id_in_kernel,
-                                                    std::vector<uint32_t> allocated_max_sizes,
-                                                    std::vector<SerializedSizeExpr> size_exprs) {
-  std::lock_guard<std::mutex> lk(adstack_sizing_info_registry_mutex_);
-  // Idempotent re-registration: same `identity_key` yields the same id across re-compiles and updates the
-  // entry's metadata + size_exprs in place. The key is just an opaque dedup token - the registry never
-  // dereferences it; all data needed by the diagnose path is copied into the entry below.
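-  // (Call-site sketch, hypothetical values: registering the same `identity_key` twice returns the same id,
-  //    uint32_t a = register_adstack_sizing_info(task, "k", 0, sizes_v1, exprs_v1);
-  //    uint32_t b = register_adstack_sizing_info(task, "k", 0, sizes_v2, exprs_v2);  // b == a, entry updated
-  // so re-compiles refresh metadata in place instead of growing the registry.)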
-  auto it = adstack_sizing_info_id_by_ptr_.find(identity_key);
-  if (it != adstack_sizing_info_id_by_ptr_.end()) {
-    auto &entry = adstack_sizing_info_registry_[it->second];
-    entry.kernel_name = kernel_name;
-    entry.task_id_in_kernel = task_id_in_kernel;
-    entry.allocated_max_sizes = std::move(allocated_max_sizes);
-    entry.size_exprs = std::move(size_exprs);
-    return it->second;
-  }
-  uint32_t id = static_cast<uint32_t>(adstack_sizing_info_registry_.size());
-  AdStackSizingInfoEntry entry;
-  entry.identity_key = identity_key;
-  entry.kernel_name = kernel_name;
-  entry.task_id_in_kernel = task_id_in_kernel;
-  entry.allocated_max_sizes = std::move(allocated_max_sizes);
-  entry.size_exprs = std::move(size_exprs);
-  adstack_sizing_info_registry_.push_back(std::move(entry));
-  adstack_sizing_info_id_by_ptr_.emplace(identity_key, id);
-  return id;
-}
-
-void AdStackCache::update_adstack_sizing_info_size_exprs(uint32_t id, std::vector<SerializedSizeExpr> size_exprs) {
-  std::lock_guard<std::mutex> lk(adstack_sizing_info_registry_mutex_);
-  if (id == 0 || id >= adstack_sizing_info_registry_.size()) {
-    return;
-  }
-  adstack_sizing_info_registry_[id].size_exprs = std::move(size_exprs);
-}
-
-std::optional<AdStackCache::AdStackSizingInfoEntry> AdStackCache::lookup_adstack_sizing_info(uint32_t id) const {
-  std::lock_guard<std::mutex> lk(adstack_sizing_info_registry_mutex_);
-  if (id == 0 || id >= adstack_sizing_info_registry_.size()) {
-    return std::nullopt;
-  }
-  return adstack_sizing_info_registry_[id];
-}
-
-std::string AdStackCache::diagnose_adstack_overflow_message(uint32_t task_id) const {
-  return diagnose_adstack_overflow(task_id).message;
-}
-
-AdStackCache::AdStackOverflowDiagnosis AdStackCache::diagnose_adstack_overflow(uint32_t task_id) const {
-  // Lazy LLVM capture: if the launcher stashed a pending ctx pointer for this launch (LLVM defers eager
-  // capture to avoid the per-launch snapshot cost), capture now before walking size_exprs. SPIR-V already
-  // captured eagerly at launch, so `pending_launch_ctx_` is null there.
-  if (pending_launch_ctx_ != nullptr) {
-    const_cast<AdStackCache *>(this)->capture_diagnose_snapshot(*pending_launch_ctx_);
-  }
-  std::string identity_block;
-  std::string disambiguation_block;
-  // Cause classifier: when the synchronous re-run produces required > allocated for ANY stack, the most likely
-  // cause is an untracked tensor mutation (DLPack-bypass etc.). When all required <= allocated, the pre-pass
-  // undersized the bound (Quadrants bug). When we cannot re-evaluate (e.g. no captured launch snapshot, or a
-  // leaf type the diagnose evaluator does not support) we fall through to the static dual-cause body.
-  enum class Cause { Unknown, DLPackBypass, QuadrantsBug };
-  Cause cause = Cause::Unknown;
-
-  if (task_id != 0) {
-    auto entry_opt = lookup_adstack_sizing_info(task_id);
-    if (entry_opt.has_value()) {
-      const auto &entry = *entry_opt;
-      identity_block = " Offending task: kernel `" + entry.kernel_name + "` offload task #" +
-                       std::to_string(entry.task_id_in_kernel) + "; per-stack allocated max_size = [";
-      for (size_t i = 0; i < entry.allocated_max_sizes.size(); ++i) {
-        if (i != 0) {
-          identity_block += ", ";
-        }
-        identity_block += std::to_string(entry.allocated_max_sizes[i]);
-      }
-      identity_block += "].\n";
-
-      // Synchronous sizer rerun: walk each stack's `SerializedSizeExpr` and evaluate against the live host /
-      // SNode state. Stacks whose tree contains an `ExternalTensorShape` or `ExternalTensorRead` leaf go
-      // through the snapshot-based `evaluate_adstack_size_expr_for_diagnose` (see its declaration for the
-      // `Device::map` design rationale).
-      // Pure host-resolvable trees go through the standard host evaluator.
-      // The disambiguation is best-effort: if every stack's tree resolves we get a precise classification;
-      // otherwise we report what we have and fall back to the static dual-cause hint.
-      if (!entry.size_exprs.empty()) {
-        std::vector<int64_t> required_sizes;
-        std::vector<bool> required_known;
-        size_t any_grew = 0;
-        size_t any_unknown = 0;
-        size_t total = std::min(entry.size_exprs.size(), entry.allocated_max_sizes.size());
-        for (size_t i = 0; i < total; ++i) {
-          const auto &expr = entry.size_exprs[i];
-          bool host_resolvable = true;
-          for (const auto &node : expr.nodes) {
-            auto k = static_cast<SizeExpr::Kind>(node.kind);
-            if (k == SizeExpr::Kind::ExternalTensorShape || k == SizeExpr::Kind::ExternalTensorRead) {
-              host_resolvable = false;
-              break;
-            }
-          }
-          int64_t v = -1;
-          if (host_resolvable && !expr.nodes.empty()) {
-            // Pure host-resolvable: SNode field loads, constants, arithmetic. `ctx == nullptr` is safe because
-            // every leaf we kept is host-resolvable; ETS / ETR are the only kinds that touch ctx and we
-            // filtered them out.
-            SizeExprLaunchScope scope;
-            v = evaluate_adstack_size_expr(expr, prog_, nullptr);
-          } else if (!expr.nodes.empty()) {
-            // Tree contains ETR / ETS leaves. The diagnose evaluator resolves them through the captured launch
-            // snapshot (`Device::map`-based ndarray reads). On failure (no snapshot, allocation cannot be
-            // mapped, unsupported dtype) the helper returns -1 and we fall through to the `?` placeholder.
-            int64_t diag = evaluate_adstack_size_expr_for_diagnose(expr, prog_);
-            if (diag >= 0) {
-              v = diag;
-            }
-          }
-          required_sizes.push_back(v);
-          required_known.push_back(!expr.nodes.empty() && v >= 0);
-          if (required_known.back() && static_cast<uint32_t>(v) > entry.allocated_max_sizes[i]) {
-            ++any_grew;
-          }
-          if (!required_known.back()) {
-            ++any_unknown;
-          }
-        }
-        if (any_grew > 0) {
-          cause = Cause::DLPackBypass;
-        } else if (any_unknown == 0 && total > 0) {
-          cause = Cause::QuadrantsBug;
-        }
-        // Only print the rerun line when at least one stack's bound resolves to a real value. With every leaf
-        // unresolved the line would be `required = [?, ?, ...]` which adds zero signal beyond the dual-cause
-        // body that follows; the omission keeps the message focused on actionable content.
-        if (any_unknown < total) {
-          disambiguation_block = " Synchronous sizer rerun: required max_size = [";
-          for (size_t i = 0; i < required_sizes.size(); ++i) {
-            if (i != 0) {
-              disambiguation_block += ", ";
-            }
-            if (required_known[i]) {
-              disambiguation_block += std::to_string(required_sizes[i]);
-            } else {
-              disambiguation_block += "?";
-            }
-          }
-          disambiguation_block += "].";
-          if (any_unknown > 0) {
-            disambiguation_block +=
-                " (`?` = sizer rerun could not resolve this stack's bound against the captured "
-                "launch state).";
-          }
-          disambiguation_block += "\n";
-        }
-      }
-    }
-  }
-
-  std::string body;
-  if (cause == Cause::DLPackBypass) {
-    body =
-        "Cause (sync sizer rerun): a tensor backing a data-dependent loop bound was mutated outside "
-        "Quadrants's tracking - typically a DLPack zero-copy mutation through a torch tensor sharing "
-        "storage with a Quadrants ndarray, or a raw pointer write through a non-torch DLPack consumer. "
-        "The cached adstack capacity was sized against the value before the mutation. "
-        "Recovery: route the mutation through Quadrants APIs (`Ndarray.write` / `fill` / kernel writes) "
-        "so the cache invalidates correctly, OR set a generous initial cap if a workload-change milestone "
-        "genuinely grew capacity. Restart the iteration / training loop from a clean state.\n";
-  } else if (cause == Cause::QuadrantsBug) {
-    body =
-        "Cause (sync sizer rerun): the freshly-computed required size does not exceed the allocated "
-        "size for any stack - this is a Quadrants bug. The pre-pass resolved the alloca to a bound "
-        "tighter than the actual runtime push count: either the enclosing loop shape is outside the "
-        "current `SizeExpr` grammar, or the Bellman-Ford analyzer undercounted the forward-pass "
-        "accumulation. Please file with the kernel IR (`QD_DUMP_IR=1`).\n";
-  } else {
-    body =
-        "Two possible causes (synchronous sizer rerun was not conclusive - some `SizeExpr` trees "
-        "depend on ndarray contents that are not host-resolvable without a per-launch context, or the "
-        "task-id slot was empty so the registry pointer could not be confirmed live):\n"
-        " 1. A tensor backing a data-dependent loop bound was mutated outside Quadrants's tracking "
-        "(typically a DLPack zero-copy mutation through a torch tensor sharing storage with a "
-        "Quadrants ndarray, or a raw pointer write through a non-torch DLPack consumer). The cached "
-        "adstack capacity was sized against the value before the mutation. Recovery: route the "
-        "mutation through Quadrants APIs (`Ndarray.write` / `fill` / kernel writes) so the cache "
-        "invalidates correctly, OR set a generous initial cap if a workload-change milestone "
-        "genuinely grew capacity. Restart the iteration / training loop from a clean state.\n"
-        " 2. (Quadrants bug) the pre-pass resolved the alloca to a bound tighter than the actual "
-        "runtime push count - the enclosing loop shape is outside the current `SizeExpr` grammar, or "
-        "the Bellman-Ford analyzer undercounted the forward-pass accumulation. Please file with the "
-        "kernel IR (`QD_DUMP_IR=1`).\n";
-  }
-  AdStackOverflowDiagnosis result;
-  result.message = identity_block + disambiguation_block + body +
-                   "Note: kernel state may be inconsistent post-overflow; do not retry the same "
-                   "step without addressing the cause and restarting from a clean state.";
-  // Flag the cache as confirmed-invalid only when the sync rerun positively identified DLPack-bypass (`required
-  // > allocated` for at least one stack with every leaf resolved against the live snapshot). Unknown is a rare
-  // fallback now that the snapshot-based evaluator handles ndarray-bound leaves; treating it as
-  // confirmed-bypass would silently retry against a possibly-broken cache. Quadrants-bug is excluded for the
-  // same reason - the next launch would re-run the same wrong sizer and produce the same wrong bound.
-  result.confirmed_invalid_cache = (cause == Cause::DLPackBypass);
-  return result;
-}
-
-void AdStackCache::capture_diagnose_snapshot(const LaunchContextBuilder &ctx) {
-  diagnose_snapshot_.data_ptrs.clear();
-  diagnose_snapshot_.dev_alloc_types.clear();
-  diagnose_snapshot_.shapes.clear();
-  // Pull just the data-pointer slot for each arg; the grad-pointer slot is irrelevant to size_expr leaves.
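-  // (E.g., illustrative: for arg 3 backed by a qd.ndarray the loop below captures
-  //    data_ptrs[3] = ctx.array_ptrs[{3, TypeFactory::DATA_PTR_POS_IN_NDARRAY}]  // a DeviceAllocation *
-  // and any other ptr_type slot for arg 3 is skipped.)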
-  for (const auto &kv : ctx.array_ptrs) {
-    if (kv.first.ptr_type == TypeFactory::DATA_PTR_POS_IN_NDARRAY) {
-      diagnose_snapshot_.data_ptrs[kv.first.arg_id] = kv.second;
-    }
-  }
-  diagnose_snapshot_.dev_alloc_types = ctx.device_allocation_type;
-  // Mirror the per-arg shape vectors `LaunchContextBuilder` populated alongside the args-buffer writes. Going
-  // through this side map rather than `args_type->get_element_offset` avoids the spurious "Cannot treat as
-  // TensorType" diagnostics emitted when an axis lookup overruns the actual rank, and keeps the diagnose path
-  // independent of `args_type` lifetime.
-  for (const auto &kv : ctx.ndarray_shapes) {
-    std::vector<int32_t> shape32(kv.second.begin(), kv.second.end());
-    diagnose_snapshot_.shapes[kv.first] = std::move(shape32);
-  }
-  diagnose_snapshot_.valid = true;
-}
-
-const AdStackCache::DiagnoseLaunchSnapshot *AdStackCache::get_diagnose_snapshot() const {
-  return diagnose_snapshot_.valid ? &diagnose_snapshot_ : nullptr;
-}
-
-void clip_effective_rows_by_loop_trip_count(std::size_t &effective_rows,
-                                            const StaticAdStackBoundExpr &bound_expr,
-                                            std::size_t dispatched_threads_ceiling,
-                                            Program *prog,
-                                            LaunchContextBuilder *ctx) {
-  if (bound_expr.loop_iter_static > 0) {
-    // Compile-time trip count: integer compare, no per-launch eval cost. Constant `SizeExpr` shapes are
-    // already collapsed into this field by the analyzer so they short-circuit the runtime eval below.
-    const std::size_t loop_iter_static = static_cast<std::size_t>(bound_expr.loop_iter_static);
-    if (loop_iter_static <= dispatched_threads_ceiling) {
-      effective_rows = std::min(effective_rows, loop_iter_static);
-    }
-    return;
-  }
-  if (bound_expr.loop_iter_size_expr.nodes.empty() || prog == nullptr || ctx == nullptr) {
-    // Runtime tree empty or no resolution context: the analyzer left this field unset for shapes the
-    // compile-time path could not cover (or the caller did not supply a `Program` / `LaunchContextBuilder`),
-    // so leave `effective_rows` alone and let the caller fall back to the unclipped reducer count.
-    return;
-  }
-  // Runtime-bounded clip: evaluate the captured trip-count `SizeExpr` only when the static field is unset
-  // (the analyzer leaves `loop_iter_static == 0` for shapes the compile-time path cannot cover, e.g.
-  // `for j in range(field[i])` / `for k in range(arr.shape[axis])`). Cost = one tree walk per launch,
-  // dominated by host scalar reads through `SNodeRwAccessorsBank` on `FieldLoad` / `ExternalTensorRead`
-  // nodes (CPU: a memory load; CUDA / AMDGPU: a 4-8 byte DtoH). The evaluator returns -1 when the tree
-  // references state that is not host-resolvable from `ctx`; in that case we leave `effective_rows`
-  // unclipped from this source.
-  const int64_t evaluated = evaluate_adstack_size_expr(bound_expr.loop_iter_size_expr, prog, ctx);
-  if (evaluated > 0 && static_cast<std::size_t>(evaluated) <= dispatched_threads_ceiling) {
-    effective_rows = std::min(effective_rows, static_cast<std::size_t>(evaluated));
-  }
-}
-
-namespace {
-
-// Shared back-end for both encoder variants. Takes already-populated stack headers (with
-// `entry_size_bytes` / `max_size_compile_time` / `heap_kind` set per stack, `root_node_idx` defaulted to
-// `-1`) plus the per-stack source trees, runs the tree-to-bytecode substitution-aware flattening, and
-// returns the packed byte buffer ready to upload to a device scratch buffer.
-std::vector<uint8_t> encode_bytecode_common(std::vector<AdStackSizeExprDeviceStackHeader> stack_headers,
-                                            const std::vector<const SerializedSizeExpr *> &exprs,
-                                            Program *prog,
-                                            LaunchContextBuilder *ctx,
-                                            const FieldLoadDeviceEmitter &fl_emitter,
-                                            int max_nodes_per_stack = 0,
-                                            ReadSink *reads = nullptr) {
-  const std::size_t n_stacks = stack_headers.size();
-  QD_ASSERT(exprs.size() == n_stacks);
-
-  std::vector<AdStackSizeExprDeviceNode> nodes;
-  std::vector<int32_t> indices;
-  nodes.reserve(n_stacks);
-  indices.reserve(n_stacks);
-
-  const bool fieldload_stays_on_device = !fl_emitter.empty();
-  for (std::size_t i = 0; i < n_stacks; ++i) {
-    auto &sh = stack_headers[i];
-    const SerializedSizeExpr *expr = exprs[i];
-    if (std::getenv("QD_DEBUG_ADSTACK")) {
-      fprintf(stderr, "[encode] stack[%zu]: expr=%p nodes=%zu max_size_ct=%u\n", i, (void *)expr,
-              expr ? expr->nodes.size() : 0, sh.max_size_compile_time);
-      if (expr) {
-        for (size_t n = 0; n < expr->nodes.size(); ++n) {
-          const auto &node = expr->nodes[n];
-          fprintf(stderr,
-                  "[encode] node[%zu]: kind=%d const=%lld snode_id=%d var_id=%d op_a=%d op_b=%d body=%d "
-                  "axis=%d",
-                  n, node.kind, (long long)node.const_value, node.snode_id, node.var_id, node.operand_a,
-                  node.operand_b, node.body_node_idx, node.arg_shape_axis);
-          if (!node.arg_id_path.empty()) {
-            fprintf(stderr, " arg_id=[");
-            for (int32_t v : node.arg_id_path)
-              fprintf(stderr, "%d,", v);
-            fprintf(stderr, "]");
-          }
-          if (!node.indices.empty()) {
-            fprintf(stderr, " idx=[");
-            for (int32_t v : node.indices)
-              fprintf(stderr, "%d,", v);
-            fprintf(stderr, "]");
-          }
-          fprintf(stderr, "\n");
-        }
-        // Host-side ground-truth evaluation: if the shader later writes a different `max_size`, the delta
-        // pinpoints a shader-side bug rather than a pre-pass / SerializedSizeExpr bug. Skip when the caller
-        // passes `ctx == nullptr` (C++-only test harnesses) and when an `ExternalTensorRead` leaf exists but
-        // `ctx->array_ptrs` has not been populated (the CPU launcher populates it via
-        // `set_host_accessible_ndarray_ptrs`; SPIR-V launchers use the device-side PSB path instead, so
-        // `array_ptrs` is empty and the host-eval would crash on the missing key).
-        if (!expr->nodes.empty() && prog != nullptr && ctx != nullptr) {
-          bool has_etr = false;
-          for (const auto &node : expr->nodes) {
-            if (static_cast<SizeExpr::Kind>(node.kind) == SizeExpr::Kind::ExternalTensorRead) {
-              has_etr = true;
-              break;
-            }
-          }
-          if (has_etr && ctx->array_ptrs.empty()) {
-            fprintf(stderr, "[encode] stack[%zu]: host_eval=skipped (ctx->array_ptrs empty)\n", i);
-          } else {
-            int64_t host_val = evaluate_adstack_size_expr(*expr, prog, ctx);
-            fprintf(stderr, "[encode] stack[%zu]: host_eval=%lld\n", i, (long long)host_val);
-          }
-        }
-      }
-    }
-    if (expr == nullptr || expr->nodes.empty()) {
-      // No symbolic bound captured - the device interpreter will route this slot to `max_size_compile_time`.
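-      // (I.e., for a hypothetical stack compiled with a constant capacity of 64 and no data-dependent
-      // bound, the interpreter sees root_node_idx == -1 and falls back to max_size_compile_time == 64.)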
-      sh.root_node_idx = -1;
-      continue;
-    }
-    auto contains_device_leaf = compute_contains_device_leaf(*expr, fieldload_stays_on_device);
-    auto free_vars = compute_free_vars(*expr);
-    const std::size_t root_src_idx = expr->nodes.size() - 1;
-    QD_ASSERT_INFO(free_vars[root_src_idx].empty(),
-                   "Adstack SizeExpr tree root for stack {} has {} free bound variable(s); a well-formed tree"
-                   " must be closed at the root because no outer MaxOverRange scope exists at publish time.",
-                   i, free_vars[root_src_idx].size());
-    // Dense-remap the tree's `var_id`s before emitting device nodes: `var_id_counter` on the host is a monotonic
-    // per-alloca counter bumped at every chased non-const index / stash, so a complex reverse-mode kernel can
-    // exceed the device interpreter's fixed-size scope capacity even with modest nesting. The encoder hard-errors
-    // here rather than letting the interpreter silently drop binds and return wrong `max_size` values.
-    auto var_id_remap = build_dense_var_id_remap(*expr);
-    const int32_t mor_depth = compute_max_mor_nesting(*expr);
-    QD_ERROR_IF(mor_depth > spirv::kAdStackSizerMaxPendingFrames,
-                "Adstack SizeExpr for stack {} has MaxOverRange nesting depth {}, which exceeds the sizer shader's "
-                "`kAdStackSizerMaxPendingFrames` ({}) pending-frame capacity. Past this cap the shader's fixed-size "
-                "pending-frame stack would index out of bounds - SPIR-V private-storage OOB is UB. Shrink the "
-                "enclosing reverse-mode loop nesting or file a bug so the cap can be raised.",
-                i, mor_depth, spirv::kAdStackSizerMaxPendingFrames);
-    const std::size_t nodes_before = nodes.size();
-    sh.root_node_idx = encode_subtree(*expr, static_cast<int32_t>(root_src_idx), contains_device_leaf, free_vars,
-                                      var_id_remap, prog, ctx, fl_emitter, nodes, indices, reads);
-    if (max_nodes_per_stack > 0) {
-      const std::size_t per_stack = nodes.size() - nodes_before;
-      QD_ERROR_IF(per_stack > static_cast<std::size_t>(max_nodes_per_stack),
-                  "Adstack SizeExpr for stack {} encodes {} device nodes, which exceeds the sizer shader's per-stack "
-                  "`kAdStackSizerMaxNodesPerStack` ({}) scratch capacity. Shrink the reverse-mode loop shape or file a "
-                  "bug - past this cap the on-device interpreter would silently truncate its private `values_arr` and "
-                  "surface later as a mysterious adstack overflow.",
-                  i, per_stack, max_nodes_per_stack);
-    }
-  }
-
-  // Pack everything into a flat byte buffer: header | stack_headers | nodes | indices.
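-  // Layout sketch (byte offsets for a hypothetical 2-stack / 5-node / 3-index encode):
-  //    header               @ 0
-  //    stack_headers[2]     @ sizeof(header)
-  //    nodes[5]             @ sizeof(header) + 2 * sizeof(stack_header)
-  //    indices[3]           @ sizeof(header) + 2 * sizeof(stack_header) + 5 * sizeof(node)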
-  AdStackSizeExprDeviceHeader header{};
-  header.n_stacks = static_cast<uint32_t>(n_stacks);
-  header.total_nodes = static_cast<uint32_t>(nodes.size());
-  header.total_indices = static_cast<uint32_t>(indices.size());
-  header._pad = 0;
-
-  const std::size_t bytes_header = sizeof(AdStackSizeExprDeviceHeader);
-  const std::size_t bytes_stack_headers = sizeof(AdStackSizeExprDeviceStackHeader) * n_stacks;
-  const std::size_t bytes_nodes = sizeof(AdStackSizeExprDeviceNode) * nodes.size();
-  const std::size_t bytes_indices = sizeof(int32_t) * indices.size();
-  const std::size_t total_bytes = bytes_header + bytes_stack_headers + bytes_nodes + bytes_indices;
-
-  std::vector<uint8_t> buffer(total_bytes);
-  std::size_t cursor = 0;
-  std::memcpy(buffer.data() + cursor, &header, bytes_header);
-  cursor += bytes_header;
-  if (bytes_stack_headers > 0) {
-    std::memcpy(buffer.data() + cursor, stack_headers.data(), bytes_stack_headers);
-    cursor += bytes_stack_headers;
-  }
-  if (bytes_nodes > 0) {
-    std::memcpy(buffer.data() + cursor, nodes.data(), bytes_nodes);
-    cursor += bytes_nodes;
-  }
-  if (bytes_indices > 0) {
-    std::memcpy(buffer.data() + cursor, indices.data(), bytes_indices);
-    cursor += bytes_indices;
-  }
-  QD_ASSERT(cursor == total_bytes);
-  return buffer;
-}
-
-} // namespace
-
-std::vector<uint8_t> encode_adstack_size_expr_device_bytecode(const AdStackSizingInfo &ad_stack,
-                                                              Program *prog,
-                                                              LaunchContextBuilder *ctx) {
-  const std::size_t n_stacks = ad_stack.allocas.size();
-  std::vector<AdStackSizeExprDeviceStackHeader> stack_headers(n_stacks);
-  std::vector<const SerializedSizeExpr *> exprs(n_stacks, nullptr);
-  for (std::size_t i = 0; i < n_stacks; ++i) {
-    stack_headers[i].entry_size_bytes = static_cast<uint32_t>(ad_stack.allocas[i].entry_size_bytes);
-    stack_headers[i].max_size_compile_time = static_cast<uint32_t>(ad_stack.allocas[i].max_size_compile_time);
-    // Float allocas land on the lazy float heap, int allocas on the eager int heap. The encoding (`0` = float,
-    // `1` = int) matches the SPIR-V `AdStackHeapKind` so the offline-cache bytecode survives a backend swap.
-    stack_headers[i].heap_kind = (ad_stack.allocas[i].heap_kind == AdStackAllocaInfo::HeapKind::Float) ? 0u : 1u;
-    if (i < ad_stack.size_exprs.size())
-      exprs[i] = &ad_stack.size_exprs[i];
-  }
-  // LLVM path: default-constructed emitter routes every FieldLoad through the host-fold (via `read_int`). That
-  // is safe on CPU / CUDA / AMDGPU where a nested accessor kernel launch does not conflict with the enclosing
-  // kernel prep.
-  FieldLoadDeviceEmitter fl_emitter{};
-  return encode_bytecode_common(std::move(stack_headers), exprs, prog, ctx, fl_emitter);
-}
-
-// Dense-only element-stride computation for a `place` leaf snode. Returns false when the chain includes a
-// non-dense snode (bitmasked / pointer / hash / bit-level), a shape with any axis <= 0, or a stride that
-// would overflow an `int32`. Success writes `*out_elem_strides` in index order (same order as
-// `SerializedSizeExprNode::indices`; each entry is the stride in element units of the leaf's primitive type,
-// not bytes). The place's byte offset within its owning tree is NOT computed here - the caller resolves it
-// separately (via `Program::get_field_in_tree_offset`) and folds it into the encoded `base_psb` once, which
-// avoids a per-load add.
-bool compute_dense_snode_strides(SNode *leaf, std::vector<int32_t> *out_elem_strides) {
-  if (leaf == nullptr) {
-    return false;
-  }
-  if (leaf->type != SNodeType::place) {
-    return false;
-  }
-  if (!leaf->is_path_all_dense) {
-    // A pointer / bitmasked / hash ancestor requires an on-device activation lookup the sizer shader does not
-    // implement.
-    // Pushing that into the shader would mean pulling the full SNode codegen subsystem in; refuse here
-    // and let the caller raise a user-visible "dense only" error.
-    return false;
-  }
-  if (leaf->is_bit_level) {
-    return false;  // quant array / bit-struct leaves need bit-packing logic we do not emit here
-  }
-  // Refuse multi-child dense parents. The stride computation below assumes the place leaf is the sole
-  // occupant of its parent dense cell: `prod(shape[k+1..])` is a valid element-unit stride only when the
-  // physical cell size equals `sizeof(leaf_dtype)`. With multiple `.place(...)` siblings under the same
-  // dense ancestor (AoS layout), the real per-axis element stride is `cell_size / sizeof(leaf_dtype)`, so
-  // this function's output would land on a sibling field at `i >= 1`. Extending to cell-size-aware strides
-  // would require walking `SNodeDescriptor` memory-offset metadata the sizer shader does not consume today;
-  // refuse and surface a clear "dense-only, single-place parent" error instead.
-  for (const SNode *anc = leaf; anc != nullptr && anc->parent != nullptr; anc = anc->parent) {
-    const SNode *p = anc->parent;
-    if (p->type == SNodeType::dense && p->ch.size() > 1) {
-      return false;
-    }
-  }
-  const int n = leaf->num_active_indices;
-  if (n < 0) {
-    return false;
-  }
-  // Scalar fields (`qd.field(dt, shape=())`) have `num_active_indices == 0`; the pre-pass emits a `FieldLoad`
-  // with an empty `indices` vector and the shader should just load `*base_psb` without any index computation.
-  // Return an empty strides vector - `compute_field_load_elem_index`'s loop iterates zero times and produces
-  // `elem_idx = 0`, which `psb_load_scalar` resolves to the exact place address.
-  std::vector<int> shape(n, 0);
-  for (int a = 0; a < n; ++a) {
-    int s = leaf->shape_along_axis(a);
-    if (s <= 0) {
-      return false;
-    }
-    shape[a] = s;
-  }
-  out_elem_strides->resize(n);
-  for (int a = 0; a < n; ++a) {
-    int64_t stride = 1;
-    for (int b = a + 1; b < n; ++b) {
-      stride *= shape[b];
-      if (stride > std::numeric_limits<int32_t>::max()) {
-        return false;  // would overflow the i32 slot; refuse rather than encode a truncated stride
-      }
-    }
-    (*out_elem_strides)[a] = static_cast<int32_t>(stride);
-  }
-  return true;
-}
-
-std::vector<uint8_t> encode_adstack_size_expr_device_bytecode_for_spirv(
-    const spirv::TaskAttributes::AdStackSizingAttribs &ad_stack,
-    Program *prog,
-    LaunchContextBuilder *ctx) {
-  const std::size_t n_stacks = ad_stack.allocas.size();
-  std::vector<AdStackSizeExprDeviceStackHeader> stack_headers(n_stacks);
-  std::vector<const SerializedSizeExpr *> exprs(n_stacks, nullptr);
-  for (std::size_t i = 0; i < n_stacks; ++i) {
-    const auto &a = ad_stack.allocas[i];
-    // The SPIR-V heaps are element-indexed (f32 / i32), so `entry_size_bytes` in the device header would be
-    // misnamed if we set it to the byte count; the SPIR-V sizer shader interprets this field as element count
-    // and scales by `2` only for the `Float` heap (to cover primal + adjoint interleaved), matching the
-    // `running_offset_float += 2u * max_size` / `running_offset_int += max_size` convention the host path used
-    // to perform and the main-kernel code already bakes into its offset arithmetic. Stamp `1` here so the
-    // sizer's multiplication by `2` for the float heap lands exactly on `2 * max_size` and the int heap on
-    // `1 * max_size`.
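-    // (Concretely, with illustrative numbers: a Float stack with max_size 8 claims 2 * 8 = 16 f32 heap
-    // elements (primal + adjoint), an Int stack with max_size 8 claims 1 * 8 = 8 i32 elements; stamping a
-    // byte count here would make the shader over-scale both by sizeof(element).)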
-    stack_headers[i].entry_size_bytes = 1;
-    stack_headers[i].max_size_compile_time = a.max_size_compile_time;
-    stack_headers[i].heap_kind = static_cast<uint32_t>(a.heap_kind);  // Float = 0, Int = 1
-    exprs[i] = &a.size_expr;
-  }
-  // SPIR-V path: emit `FieldLoad` as `kFieldLoad` device nodes so the sizer shader can PSB-load the field value
-  // in place. This avoids `SNodeRwAccessorsBank::Accessors::read_int`, whose nested accessor-kernel launch
-  // deadlocks inside MoltenVK's descriptor-set bind path when the outer launch has already opened its command
-  // buffer. The emitter resolves each snode's tree-root PSB via the program's compute device and pre-computes
-  // the per-axis byte strides from the dense snode shape.
-  FieldLoadDeviceEmitter fl_emitter{};
-  fl_emitter.fetch = [prog](SNode *snode, uint64_t *out_base_psb, std::vector<int32_t> *out_elem_strides) -> bool {
-    if (snode == nullptr || prog == nullptr) {
-      return false;
-    }
-    if (!compute_dense_snode_strides(snode, out_elem_strides)) {
-      return false;
-    }
-    const int tree_id = snode->get_snode_tree_id();
-    DevicePtr tree_root_devptr = prog->get_snode_tree_device_ptr(tree_id);
-    Device *dev = prog->get_compute_device();
-    if (dev == nullptr) {
-      return false;
-    }
-    // `get_memory_physical_pointer` returns the Vulkan `bufferDeviceAddress` / Metal equivalent for the buffer
-    // that backs the snode tree's root. The place's byte offset within the tree comes from the compiled snode
-    // descriptor table (`snode_descriptors[id].mem_offset_in_parent_cell` walked up to root), NOT from
-    // `SNode::offset_bytes_in_parent_cell` which is a frontend-only field that stays zero on the SPIR-V path.
-    // Using the wrong offset silently reads a sibling field (typically the first `qd.field` declared in the
-    // program), which looks like a returning-zero bug at runtime.
-    uint64_t root_psb = dev->get_memory_physical_pointer(tree_root_devptr);
-    if (root_psb == 0) {
-      return false;
-    }
-    size_t place_byte_offset = prog->get_field_in_tree_offset(tree_id, snode);
-    if (std::getenv("QD_DEBUG_ADSTACK")) {
-      // Pull the live value via the RwAccessors as a ground-truth check on the `PSB + place_off` pair the
-      // sizer shader will consume: if the shader later reads a different i32 than `live_val`, the byte offset
-      // we encoded is wrong even though `shape_along_axis` lined up and the tree dispatcher emitted a sensible
-      // PSB base. `read_int` launches its own accessor kernel plus a `synchronize()`, which is safe here
-      // because the encoder runs outside any in-flight main-kernel launch.
-      std::vector<int> idx_zero(snode->num_active_indices, 0);
-      int64_t live_val = prog->get_snode_rw_accessors_bank().get(snode).read_int(idx_zero);
-      fprintf(stderr,
-              "[fl.fetch] snode_id=%d type=%d dense=%d n_axes=%d tree_id=%d root_psb=0x%llx place_off=%zu live=%lld\n",
-              snode->id, (int)snode->type, (int)snode->is_path_all_dense, snode->num_active_indices, tree_id,
-              (unsigned long long)root_psb, place_byte_offset, (long long)live_val);
-    }
-    *out_base_psb = root_psb + static_cast<uint64_t>(place_byte_offset);
-    return true;
-  };
-  // Bytecode fast path: replay the recorded host-fold reads against the live state and reuse the cached
-  // bytecode if every input still matches. The full encode runs only on cache miss.
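-  // (Replay contract, schematic observation: a recorded FieldLoadObs with snode_id 7 and observed_value 42
-  // hits only if snode 7's write gen has not advanced or a re-read still yields 42; one mismatch falls
-  // through to the full encode below and re-records fresh observations.)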
-  if (prog != nullptr) {
-    std::vector<uint8_t> cached;
-    if (prog->adstack_cache().try_spirv_bytecode_cache_hit(prog, static_cast<const void *>(&ad_stack), ctx,
-                                                           cached)) {
-      return cached;
-    }
-  }
-  std::vector<AdStackCache::SizeExprReadObservation> reads;
-  std::vector<uint8_t> bytecode = encode_bytecode_common(std::move(stack_headers), exprs, prog, ctx, fl_emitter,
-                                                         spirv::kAdStackSizerMaxNodesPerStack, &reads);
-  if (prog != nullptr) {
-    prog->adstack_cache().record_spirv_bytecode_eval(static_cast<const void *>(&ad_stack), bytecode,
-                                                     std::move(reads));
-  }
-  return bytecode;
-}
-
-void bump_writes_for_kernel_llvm(Program *prog,
-                                 LaunchContextBuilder *ctx,
-                                 const std::vector<OffloadedTask> &offloaded_tasks) {
-  if (prog == nullptr) {
-    return;
-  }
-  auto bump_data_ptr = [&](int arg_id) {
-    ArgArrayPtrKey data_key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
-    auto it = ctx->array_ptrs.find(data_key);
-    if (it != ctx->array_ptrs.end() && it->second != nullptr) {
-      prog->adstack_cache().bump_ndarray_data_gen(it->second);
-    }
-  };
-  for (const auto &task : offloaded_tasks) {
-    for (int snode_id : task.snode_writes) {
-      prog->adstack_cache().bump_snode_write_gen(snode_id);
-    }
-    for (int arg_id : task.arr_writes) {
-      bump_data_ptr(arg_id);
-    }
-    // Read-only `DevAllocType::kNone` args also need a bump: the user's host array is either H2D-blitted to a
-    // temporary device buffer (CUDA / AMDGPU) or read directly (CPU), and in both cases the data pointer used as
-    // the cache key is stable across launches, so a content mutation the user performed outside Quadrants's
-    // tracking is invisible to the metadata cache without an explicit bump. Mirrors the SPIR-V `kone_h2d_blit`
-    // rule in `bump_writes_for_kernel_spirv`.
-    for (int arg_id : task.arr_reads) {
-      auto type_it = ctx->device_allocation_type.find(arg_id);
-      if (type_it == ctx->device_allocation_type.end() ||
-          type_it->second != LaunchContextBuilder::DevAllocType::kNone) {
-        continue;
-      }
-      bump_data_ptr(arg_id);
-    }
-  }
-}
-
-void bump_writes_for_kernel_llvm(Program *prog,
-                                 LaunchContextBuilder *ctx,
-                                 const std::vector<std::vector<int>> &snode_writes_per_task,
-                                 const std::vector<std::vector<int>> &arr_writes_per_task,
-                                 const std::vector<std::vector<int>> &arr_reads_per_task) {
-  if (prog == nullptr) {
-    return;
-  }
-  auto bump_data_ptr = [&](int arg_id) {
-    ArgArrayPtrKey data_key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
-    auto it = ctx->array_ptrs.find(data_key);
-    if (it != ctx->array_ptrs.end() && it->second != nullptr) {
-      prog->adstack_cache().bump_ndarray_data_gen(it->second);
-    }
-  };
-  for (const auto &task_snodes : snode_writes_per_task) {
-    for (int snode_id : task_snodes) {
-      prog->adstack_cache().bump_snode_write_gen(snode_id);
-    }
-  }
-  for (const auto &task_args : arr_writes_per_task) {
-    for (int arg_id : task_args) {
-      bump_data_ptr(arg_id);
-    }
-  }
-  // Read-only `DevAllocType::kNone` args: see the comment in the CUDA / AMDGPU overload for why CPU LLVM also
-  // needs the bump. Empty `arr_reads_per_task` is the legal cache-miss path (offline-cache load that did not
-  // capture per-task arr_reads); skip the loop without raising.
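-  // (Scenario sketch, hypothetical: a kernel that only READS a numpy array still bumps that pointer's
-  // generation here, so a host-side `np_arr[...] = ...` mutation between launches is caught by the next
-  // cache lookup even though no kernel ever wrote the buffer.)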
-  for (const auto &task_args : arr_reads_per_task) {
-    for (int arg_id : task_args) {
-      auto type_it = ctx->device_allocation_type.find(arg_id);
-      if (type_it == ctx->device_allocation_type.end() ||
-          type_it->second != LaunchContextBuilder::DevAllocType::kNone) {
-        continue;
-      }
-      bump_data_ptr(arg_id);
-    }
-  }
-}
-
-void bump_writes_for_kernel_spirv(
-    Program *prog,
-    LaunchContextBuilder *ctx,
-    const std::vector<spirv::TaskAttributes> &task_attribs,
-    const std::vector<std::pair<std::vector<int>, irpass::ExternalPtrAccess>> &arr_access) {
-  if (prog == nullptr) {
-    return;
-  }
-  for (const auto &task : task_attribs) {
-    for (int snode_id : task.snode_writes) {
-      prog->adstack_cache().bump_snode_write_gen(snode_id);
-    }
-  }
-  for (const auto &kv : arr_access) {
-    const std::vector<int> &indices = kv.first;
-    uint32_t access = uint32_t(kv.second);
-    QD_ASSERT(indices.size() == 1);
-    int arg_id = indices[0];
-    bool kernel_writes = (access & uint32_t(irpass::ExternalPtrAccess::WRITE)) != 0;
-    bool kone_h2d_blit = (access & uint32_t(irpass::ExternalPtrAccess::READ)) != 0 &&
-                         ctx->device_allocation_type[arg_id] == LaunchContextBuilder::DevAllocType::kNone;
-    if (!kernel_writes && !kone_h2d_blit) {
-      continue;
-    }
-    ArgArrayPtrKey data_key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
-    auto it = ctx->array_ptrs.find(data_key);
-    if (it != ctx->array_ptrs.end()) {
-      prog->adstack_cache().bump_ndarray_data_gen(it->second);
-    }
-  }
-}
-
-} // namespace quadrants::lang
diff --git a/quadrants/program/adstack_size_expr_eval.h b/quadrants/program/adstack_size_expr_eval.h
index ec8f777203..0e0919e2c1 100644
--- a/quadrants/program/adstack_size_expr_eval.h
+++ b/quadrants/program/adstack_size_expr_eval.h
@@ -1,417 +1,14 @@
 #pragma once
 
-#include <cstddef>
-#include <cstdint>
-#include <mutex>
-#include <optional>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "quadrants/codegen/llvm/llvm_compiled_data.h"
-#include "quadrants/codegen/spirv/kernel_utils.h"
-#include "quadrants/ir/adstack_size_expr.h"
-#include "quadrants/program/program.h"
-#include "quadrants/transforms/static_adstack_analysis.h"
-
-namespace quadrants::lang {
-
-class LaunchContextBuilder;
-class Program;
-
-// Adstack-specific state owned by `Program` and routed through `program->adstack_cache().method(...)`. Holds two
-// orthogonal pieces:
-//   1. The per-task adstack-sizer metadata caches (SPIR-V + LLVM-GPU), the encoded SPIR-V bytecode cache, the
-//      per-launch SizeExpr-eval result cache, and the per-snode / per-DeviceAllocation generation counters that
-//      drive precise invalidation.
-//   2. The adstack-overflow identity registry + diagnostic classifier that the codegen-emitted overflow path
-//      reads through (`Program::launch_kernel` populates `DiagnoseLaunchSnapshot`; the registry maps task ids
-//      to kernel + offload-task identities + per-stack capacities, and `diagnose_adstack_overflow` runs the
-//      synchronous sizer rerun against the captured snapshot to classify the failure mode).
-// Both pieces are adstack-internal and lived in `Program` historically; consolidating them here keeps the
-// `Program` surface focused on cross-feature program state.
-class AdStackCache {
- public:
-  // Back-reference to `Program` is used by the diagnose path to reach `evaluate_adstack_size_expr` /
-  // `evaluate_adstack_size_expr_for_diagnose` (free functions that take `Program *`) and by the registry methods
-  // to access `get_compute_device()` for `Device::map`-based ndarray reads. Stored as a raw pointer because
-  // `AdStackCache` is owned by `Program` and shares its lifetime - the back-ref cannot dangle.
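-  // (Ownership sketch - the member name is illustrative, not the actual `Program` layout:
-  //    class Program { ...  AdStackCache adstack_cache_{this};  ... };
-  // holding the cache by value sets the back-ref exactly once at construction.)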
-  explicit AdStackCache(Program *prog) : prog_(prog) {
-  }
-
-  // One input read observed during an `evaluate_adstack_size_expr` walk. The cache entry records these so a
-  // subsequent lookup re-reads the same inputs and compares to `observed_value`; a single mismatch forces a full
-  // re-walk. `observed_gen` snapshots `snode_write_gen` (FieldLoadObs) or `ndarray_data_gen` (ExternalReadObs) at
-  // record time. The replay walk uses it as a fast-path short-circuit: if the gen counter has not advanced, the
-  // value cannot have changed and the dispatch (reader kernel for SNode reads, device-pointer deref for ndarray
-  // reads) is skipped. ExternalShapeObs reads the args buffer per launch (cheap host memory access), so it does
-  // not need a gen and leaves this field at 0.
-  struct SizeExprReadObservation {
-    enum Kind : uint8_t { FieldLoadObs, ExternalShapeObs, ExternalReadObs };
-    Kind kind;
-    int snode_id;
-    std::vector<int> indices;
-    std::vector<int32_t> arg_id_path;
-    int arg_shape_axis;
-    int prim_dt;
-    int64_t observed_value;
-    uint64_t observed_gen{0};
-    void *observed_devalloc{nullptr};
-  };
-  struct SizeExprCacheEntry {
-    int64_t result;
-    std::vector<SizeExprReadObservation> reads;
-  };
-  bool try_size_expr_cache_hit(Program *prog,
-                               const SerializedSizeExpr *expr_key,
-                               LaunchContextBuilder *ctx,
-                               int64_t &out_result);
-  void record_size_expr_eval(const SerializedSizeExpr *expr_key,
-                             int64_t result,
-                             std::vector<SizeExprReadObservation> reads);
-  void invalidate_size_expr() {
-    size_expr_cache_.clear();
-  }
-
-  // Cache for encoded SPIR-V adstack-sizer bytecode. Same dep-tracking contract as `try_size_expr_cache_hit`
-  // but the cached payload is the encoded bytes rather than an integer.
-  struct SpirvBytecodeCacheEntry {
-    std::vector<uint8_t> bytecode;
-    std::vector<SizeExprReadObservation> reads;
-  };
-  bool try_spirv_bytecode_cache_hit(Program *prog,
-                                    const void *attribs_key,
-                                    LaunchContextBuilder *ctx,
-                                    std::vector<uint8_t> &out_bytecode);
-  void record_spirv_bytecode_eval(const void *attribs_key,
-                                  std::vector<uint8_t> bytecode,
-                                  std::vector<SizeExprReadObservation> reads);
-  void invalidate_spirv_bytecode() {
-    spirv_bytecode_cache_.clear();
-  }
-
-  // Per-task adstack metadata output cache for the SPIR-V on-device sizer.
-  struct PerTaskAdStackCacheEntry {
-    std::vector<uint32_t> metadata;
-    uint32_t stride_float{0};
-    uint32_t stride_int{0};
-    std::vector<std::pair<int, uint64_t>> snode_gens;
-    std::vector<std::pair<void *, uint64_t>> arg_gens;
-  };
-  bool try_per_task_ad_stack_cache_hit(const void *attribs_key,
-                                       LaunchContextBuilder *ctx,
-                                       PerTaskAdStackCacheEntry &out);
-  void record_per_task_ad_stack(const void *attribs_key,
-                                std::vector<uint32_t> metadata,
-                                uint32_t stride_float,
-                                uint32_t stride_int,
-                                std::vector<std::pair<int, uint64_t>> snode_gens,
-                                std::vector<std::pair<void *, uint64_t>> arg_gens);
-  void invalidate_per_task_ad_stack() {
-    per_task_ad_stack_cache_.clear();
-  }
-
-  // Per-task adstack metadata output cache for the LLVM-GPU on-device sizer (CUDA + AMDGPU).
-  struct LlvmPerTaskAdStackCacheEntry {
-    std::vector<uint64_t> offsets;
-    std::vector<uint32_t> max_sizes;
-    uint64_t stride_combined{0};
-    uint64_t stride_float{0};
-    uint64_t stride_int{0};
-    std::vector<std::pair<int, uint64_t>> snode_gens;
-    std::vector<std::pair<void *, uint64_t>> arg_gens;
-  };
-  bool try_llvm_per_task_ad_stack_cache_hit(const void *attribs_key,
-                                            LaunchContextBuilder *ctx,
-                                            LlvmPerTaskAdStackCacheEntry &out);
-  void record_llvm_per_task_ad_stack(const void *attribs_key,
-                                     std::vector<uint64_t> offsets,
-                                     std::vector<uint32_t> max_sizes,
-                                     uint64_t stride_combined,
-                                     uint64_t stride_float,
-                                     uint64_t stride_int,
-                                     std::vector<std::pair<int, uint64_t>> snode_gens,
-                                     std::vector<std::pair<void *, uint64_t>> arg_gens);
-  void invalidate_llvm_per_task_ad_stack() {
-    llvm_per_task_ad_stack_cache_.clear();
-  }
-
-  // Bulk-invalidate just the per-task adstack metadata caches on the overflow raise path. The
-  // `size_expr_cache_` and `spirv_bytecode_cache_` are intentionally NOT cleared: they self-validate via
-  // per-read observation walks on the next lookup, so a DLPack-bypass mutation surfaces there as a normal
-  // observation mismatch and triggers a fresh evaluation without explicit eviction. The per-task metadata
-  // caches need a force-drop because their gen-counter snapshots match when the user's mutation bypassed our
-  // tracking. Invalidation is bulk (every task) rather than targeted (just the offender) because a single
-  // shared DLPack / torch view can back multiple tasks in the same kernel queue: targeted invalidation would
-  // let the next launch hit a stale entry on a different task that reads the same now-mutated tensor and
-  // overflow again.
-  void invalidate_all_per_task() {
-    invalidate_per_task_ad_stack();
-    invalidate_llvm_per_task_ad_stack();
-  }
-
-  uint64_t snode_write_gen(int snode_id) const {
-    auto it = snode_write_gen_.find(snode_id);
-    return it == snode_write_gen_.end() ? 0u : it->second;
-  }
-  void bump_snode_write_gen(int snode_id) {
-    ++snode_write_gen_[snode_id];
-  }
-  uint64_t ndarray_data_gen(void *devalloc_ptr) const {
-    auto it = ndarray_data_gen_.find(devalloc_ptr);
-    return it == ndarray_data_gen_.end() ? 0u : it->second;
-  }
-  void bump_ndarray_data_gen(void *devalloc_ptr) {
-    ++ndarray_data_gen_[devalloc_ptr];
-  }
-  // Drop a per-DeviceAllocation entry. Called from `Ndarray::~Ndarray()` so the holder address can be reused by
-  // a future allocation without inheriting the destroyed ndarray's stale generation. Leftover snapshots in
-  // `per_task_ad_stack_cache_` / `llvm_per_task_ad_stack_cache_` referencing the dropped key fall back to gen=0
-  // on the next lookup (their stored snapshot will not match), which forces a fresh sizer dispatch and
-  // self-heals.
-  void erase_ndarray_data_gen(void *devalloc_ptr) {
-    ndarray_data_gen_.erase(devalloc_ptr);
-  }
-
-  // -----------------------------------------------------------------------------------------------------------
-  // Adstack-overflow identity registry + diagnostic classifier
-  // -----------------------------------------------------------------------------------------------------------
-  // Codegen registers each `OffloadedTask::ad_stack` once per kernel compilation and bakes the assigned id as
-  // an immediate into the lazy-claim overflow path; on overflow the codegen emits `cmpxchg(0, id)` against the
-  // pinned-host task-id slot. The host raise site reads the slot and routes through
-  // `diagnose_adstack_overflow_message(id)` to look up the kernel name, task index, and per-stack metadata for
-  // an enriched error message.
-  // Pointer ownership stays with `OffloadedTask`; entries are added but not removed - the registry size is
-  // bounded by the number of adstack-bearing tasks compiled in the program's lifetime, typically dozens. The
-  // diagnose path NEVER dereferences `identity_key`; all size-expression data is stored inline (`size_exprs`)
-  // so the entry is self-contained and immune to lifetime issues from the underlying `AdStackSizingInfo`
-  // (LLVM) / `AdStackSizingAttribs` (SPIR-V) struct moves.
-  struct AdStackSizingInfoEntry {
-    const void *identity_key{nullptr};
-    std::string kernel_name;
-    int task_id_in_kernel{0};
-    std::vector<uint32_t> allocated_max_sizes;
-    std::vector<SerializedSizeExpr> size_exprs;
-  };
-  uint32_t register_adstack_sizing_info(const void *identity_key,
-                                        const std::string &kernel_name,
-                                        int task_id_in_kernel,
-                                        std::vector<uint32_t> allocated_max_sizes,
-                                        std::vector<SerializedSizeExpr> size_exprs);
-  // Refresh just the `size_exprs` snapshot in an existing registry entry. Used by the LLVM launcher on the first
-  // launch of a task whose codegen-time registration could not capture size_exprs (the codegen-time
-  // `current_task->ad_stack` had not yet been finalized). No-op for `id == 0` and ids outside the registry range.
-  void update_adstack_sizing_info_size_exprs(uint32_t id, std::vector<SerializedSizeExpr> size_exprs);
-  // Returns a *copy* of the registry entry (not a pointer into the underlying vector) so the caller can safely
-  // hold the data across operations that might trigger another `register_adstack_sizing_info` and grow /
-  // reallocate the registry vector (e.g. `evaluate_adstack_size_expr` dispatching a reader kernel that compiles
-  // a fresh task). Returns `std::nullopt` for the sentinel id `0` and for out-of-range ids.
-  std::optional<AdStackSizingInfoEntry> lookup_adstack_sizing_info(uint32_t id) const;
-  // Format a diagnostic message for an overflow signal. `task_id` is the value read from the pinned-host task-id
-  // slot (0 if no thread overflowed; otherwise the registry id of the first overflowing task). The `message`
-  // field is embedded into the `QuadrantsAssertionError` raised at the poll site. The `confirmed_invalid_cache`
-  // field is true only when the synchronous sizer rerun classified the failure as a stale-cache /
-  // DLPack-bypass case (`required > allocated` for at least one stack with every leaf resolved against the
-  // captured launch snapshot); the caller (LLVM `check_adstack_overflow` / SPIR-V `GfxRuntime::synchronize`)
-  // uses it to decide whether to bulk-invalidate the per-task metadata caches so the next launch auto-recovers.
-  // We deliberately do NOT invalidate on Unknown / Quadrants-bug because invalidating would mask sizer bugs and
-  // could let a never-confirmed cause silently retry against a possibly-broken cache.
-  struct AdStackOverflowDiagnosis {
-    std::string message;
-    bool confirmed_invalid_cache{false};
-  };
-  AdStackOverflowDiagnosis diagnose_adstack_overflow(uint32_t task_id) const;
-  // Convenience wrapper that returns just the message string; production code uses `diagnose_adstack_overflow`
-  // to also act on the confirmed-cause signal.
-  std::string diagnose_adstack_overflow_message(uint32_t task_id) const;
-
-  // Snapshot of the most recent launch's context fields needed by `diagnose_adstack_overflow` to resolve
-  // ndarray-bound `SizeExpr` leaves (`ExternalTensorRead` / `ExternalTensorShape`) at error time, when the
-  // original `LaunchContextBuilder` is gone.
-  // Captured at the top of `Program::launch_kernel` BEFORE the launcher rewrites `array_ptrs` (the CPU
-  // launcher's `set_host_accessible_ndarray_ptrs` overwrites the `DeviceAllocation *` entry with a raw host
-  // pointer; capturing earlier keeps the original handle so the diagnose path can use the unified `Device::map`
-  // API instead of trusting backend-specific semantics).
-  //
-  // Design choice (vs. re-dispatching the on-device sizer at diagnose time): `Device::map` is virtual on
-  // every backend (CPU / CUDA / AMDGPU / Vulkan / Metal), so this snapshot-plus-map approach gets backend
-  // parity for free without re-entering the launcher's pipeline-setup machinery (compute pipelines /
-  // descriptor sets / command buffers / sync fences). The diagnose path stays out of the launch lifecycle.
-  struct DiagnoseLaunchSnapshot {
-    bool valid{false};
-    // arg_id -> ctx->array_ptrs[(arg_id, DATA_PTR_POS_IN_NDARRAY)]. For `kNone` numpy passthrough this is a
-    // raw host pointer. For `kNdarray` (qd.ndarray) this is a `DeviceAllocation *` handle the diagnose path
-    // dereferences via `Device::map`. Captured before the CPU launcher's `set_host_accessible_ndarray_ptrs`
-    // overwrite so the handle is uniform across backends.
-    std::unordered_map<int, void *> data_ptrs;
-    std::unordered_map<int, LaunchContextBuilder::DevAllocType> dev_alloc_types;
-    // Pre-extracted ndarray shapes (`ctx->get_struct_arg_host({arg_id, SHAPE_POS, axis})`) so the
-    // diagnose evaluator does not need a live `LaunchContextBuilder` to resolve `ExternalTensorShape` or
-    // multi-axis `ExternalTensorRead` strides.
-    std::unordered_map<int, std::vector<int32_t>> shapes;
-  };
-  // Capture the per-launch fields the diagnose evaluator needs (see `DiagnoseLaunchSnapshot`'s definition for
-  // the design rationale and field-by-field semantics). Called eagerly from `Program::launch_kernel` only on
-  // backends where the launch ctx is gone by the time overflow is detected (SPIR-V at `synchronize`); on LLVM
-  // backends the per-launch overflow poll runs while ctx is still in scope, so we stash the ctx pointer with
-  // `set_pending_launch_ctx` and let `diagnose_adstack_overflow` capture lazily on the (rare) overflow path.
-  void capture_diagnose_snapshot(const LaunchContextBuilder &ctx);
-  // Lazy-capture handoff: `Program::launch_kernel` on LLVM backends sets this to the in-scope ctx before
-  // forwarding into the launcher and clears it after the per-launch overflow poll returns. If the poll fires,
-  // `diagnose_adstack_overflow` reads the pointer and captures the snapshot just in time. Stored as a raw
-  // pointer because it is transient per-launch and never outlives the call frame that set it.
-  void set_pending_launch_ctx(const LaunchContextBuilder *ctx) {
-    pending_launch_ctx_ = ctx;
-  }
-  // Read-only accessor for the latest snapshot, used by `diagnose_adstack_overflow` to resolve ndarray-bound
-  // size_expr leaves. Returns `nullptr` when no launch has happened yet (e.g. a freshly constructed `Program`
-  // hits `synchronize` during teardown without a prior kernel launch).
-  const DiagnoseLaunchSnapshot *get_diagnose_snapshot() const;
-
- private:
-  Program *prog_{nullptr};
-  std::unordered_map<const SerializedSizeExpr *, SizeExprCacheEntry> size_expr_cache_;
-  std::unordered_map<const void *, SpirvBytecodeCacheEntry> spirv_bytecode_cache_;
-  std::unordered_map<const void *, PerTaskAdStackCacheEntry> per_task_ad_stack_cache_;
-  std::unordered_map<const void *, LlvmPerTaskAdStackCacheEntry> llvm_per_task_ad_stack_cache_;
-  std::unordered_map<int, uint64_t> snode_write_gen_;
-  std::unordered_map<void *, uint64_t> ndarray_data_gen_;
-
-  // Adstack-overflow identity registry storage.
-  // Index 0 is reserved as the "no overflow" sentinel so the codegen-emitted `cmpxchg(0, id)` cleanly
-  // distinguishes "task id recorded" from "slot still clean". The reverse lookup map (keyed by `identity_key`)
-  // keeps `register_adstack_sizing_info` idempotent across re-launches of the same kernel.
-  std::vector<AdStackSizingInfoEntry> adstack_sizing_info_registry_{AdStackSizingInfoEntry{}};
-  std::unordered_map<const void *, uint32_t> adstack_sizing_info_id_by_ptr_;
-  mutable std::mutex adstack_sizing_info_registry_mutex_;
-  // Latest captured launch context snapshot for the diagnose path's ndarray-bound leaf resolution. See
-  // `DiagnoseLaunchSnapshot`'s comment above for why we capture in `Program::launch_kernel` before the launcher
-  // forwards.
-  // Single-threaded by construction: `capture_diagnose_snapshot` runs from `Program::launch_kernel` (Python
-  // launcher thread) and `get_diagnose_snapshot` runs from `diagnose_adstack_overflow` on the same thread; no
-  // mutex needed. The codegen-time identity registry above keeps its mutex because it is hit from compilation
-  // worker threads.
-  DiagnoseLaunchSnapshot diagnose_snapshot_;
-  // Transient ctx handoff for the lazy LLVM capture path. See `set_pending_launch_ctx`.
-  const LaunchContextBuilder *pending_launch_ctx_{nullptr};
-};
-
-// Evaluates a compile-time captured `SerializedSizeExpr` against the current field state of `prog` and the
-// per-launch argument values in `ctx`, returning the concrete adstack capacity for this launch. Scalar i32/i64
-// field loads are serviced by `SNodeRwAccessorsBank` (one reader-kernel dispatch each); ndarray-argument shapes
-// are read from `ctx->get_struct_arg`; constants and arithmetic are folded in plain C++; `MaxOverRange`
-// enumerates its range and takes the max of the body expression across the bound variable. Returns -1 when the
-// expression is empty (no symbolic bound captured), signalling to the caller to use the compile-time fallback.
-int64_t evaluate_adstack_size_expr(const SerializedSizeExpr &expr, Program *prog, LaunchContextBuilder *ctx);
-
-// Diagnose-time variant that evaluates the same `SerializedSizeExpr` against the captured
-// `AdStackCache::DiagnoseLaunchSnapshot` rather than a live `LaunchContextBuilder`. Used by
-// `AdStackCache::diagnose_adstack_overflow` to resolve `ExternalTensorRead` / `ExternalTensorShape` leaves at
-// error time against the live (potentially mutated) ndarray contents, without needing the launch ctx that is
-// gone by sync time on async backends. The cross-backend `Device::map(*allocation, &host_ptr)` path is the
-// design pivot - see `AdStackCache::DiagnoseLaunchSnapshot`'s comment for the rationale (vs. re-dispatching
-// the on-device sizer). Returns -1 if any leaf cannot be resolved (e.g. an arg_id missing from the snapshot,
-// or an allocation whose `Device::map` fails); callers fall back to the static dual-cause body in that case.
-int64_t evaluate_adstack_size_expr_for_diagnose(const SerializedSizeExpr &expr, Program *prog);
-
-// RAII guard opening a thread-local read-cache scope. Every nested `evaluate_adstack_size_expr` running inside
-// the scope shares one cache, so repeated `(snode_id, indices)` reads share a single reader-kernel dispatch.
-// Place around any block that calls `evaluate_adstack_size_expr` more than once back-to-back.
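-// Usage sketch (hypothetical caller, two back-to-back evaluations sharing one scope):
-//   SizeExprLaunchScope scope;
-//   int64_t a = evaluate_adstack_size_expr(expr_a, prog, ctx);
-//   int64_t b = evaluate_adstack_size_expr(expr_b, prog, ctx);  // repeated field reads hit the scope cache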
-class SizeExprLaunchScope {
- public:
-  SizeExprLaunchScope();
-  ~SizeExprLaunchScope();
-  SizeExprLaunchScope(const SizeExprLaunchScope &) = delete;
-  SizeExprLaunchScope &operator=(const SizeExprLaunchScope &) = delete;
-
- private:
-  bool owns_;
-};
-
-// Flattens every alloca's `SerializedSizeExpr` tree into the device-readable bytecode defined in
-// `quadrants/ir/adstack_size_expr_device.h` and returns the raw bytes ready to upload to a device scratch
-// buffer. Two transforms happen at encoding time:
-//
-//   1. Pre-substitution of host-resolvable subtrees. Any subtree whose leaves consist only of `Const`,
-//      `BoundVariable`, `FieldLoad`, and `ExternalTensorShape` nodes - i.e. nothing that requires an
-//      on-device pointer dereference - is collapsed to a single `Const` node by running the existing host
-//      evaluator over it. This routes `FieldLoad` through `SNodeRwAccessorsBank::read_int` (which itself
-//      handles device-to-host via a tiny reader kernel on GPU) and `ExternalTensorShape` through the kernel
-//      arg buffer that the host just wrote, so the device interpreter in `runtime.cpp` never has to walk
-//      an SNode tree or index into `args_type` - it only has to handle arithmetic plus
-//      `ExternalTensorRead`, which is the one leaf kind that actually needs device-resident memory.
-//   2. `arg_buffer_offset` precomputation. Every surviving `ExternalTensorRead` carries the byte offset into
-//      `RuntimeContext::arg_buffer` where the referenced ndarray's data pointer lives, resolved here against
-//      `ctx->args_type->get_element_offset({arg_id, DATA_PTR_POS_IN_NDARRAY})`. The device interpreter does
-//      a direct `*(void **)(arg_buffer + offset)` to fetch the ndarray pointer at launch time - no map
-//      lookup, no `LaunchContextBuilder` touches from device code.
-//
-// Mixed subtrees that contain both an `ExternalTensorRead` and a `FieldLoad` are rejected with a hard error:
-// the device interpreter does not support on-device SNode access, so a `FieldLoad` that cannot be lifted out
-// to a host-resolvable `Const` has nowhere to run. The grammar today does not emit this combination and no
-// user kernel has been observed to do so; the hard error pins the assumption so a future regression cannot
-// slip past.
-std::vector<uint8_t> encode_adstack_size_expr_device_bytecode(const AdStackSizingInfo &ad_stack,
-                                                              Program *prog,
-                                                              LaunchContextBuilder *ctx);
-
-// SPIR-V-flavour encoder. Same transforms as the LLVM variant, but sources per-stack metadata from
-// `TaskAttributes::AdStackSizingAttribs::allocas` (each entry has a `HeapKind` - `Float = 0`, `Int = 1` -
-// that routes the stack onto the `AdStackHeapFloat` or `AdStackHeapInt` backing buffer on the host). The
-// `heap_kind` field of each `AdStackSizeExprDeviceStackHeader` carries that selector into the shader; the
-// shader splits the running-offset / stride computation into a float accumulator and an int accumulator so
-// the output metadata buffer matches the layout the main kernel already reads today:
-// `[stride_float, stride_int, (offset_i, max_size_i)*]`. The `entry_size_bytes` field is set to 1 on the
-// SPIR-V path because the backing buffers are element-indexed (f32 / i32) rather than byte-indexed and the
-// shader multiplies by `2` only for the `Float` heap (primal + adjoint interleaved) - see the running-offset
-// arithmetic in `GfxRuntime::launch_kernel` for the convention this matches.
-std::vector<char> encode_adstack_size_expr_device_bytecode_for_spirv(
-    const spirv::TaskAttributes::AdStackSizingAttribs &ad_stack,
-    Program *prog,
-    LaunchContextBuilder *ctx);
-
-// Apply the captured per-task loop trip-count clip to `effective_rows`. Each loop iteration of an adstack
-// task claims at most one row at the LCA-block, so the heap needs at most `trip_count` rows regardless of
-// how many cells of an oversized gating SNode/ndarray the reducer counted. Two trip-count sources, picked
-// in order: `bound_expr.loop_iter_static` (compile-time-known constant, integer compare) and
-// `bound_expr.loop_iter_size_expr` (per-launch tree walk via `evaluate_adstack_size_expr`). Both are
-// gated by `dispatched_threads_ceiling` so a `dynamic_gpu_range_for` that exceeds the dispatch cap and
-// serialises iterations across threads (each thread reaches the LCA-block multiple times) does not
-// accidentally undersize the heap; pass `std::numeric_limits<std::size_t>::max()` to disable the
-// ceiling. No-op when the static field is zero AND the SizeExpr is empty (the analyzer leaves both
-// unset for shapes the compile-time path cannot cover) - the caller's pre-clip `effective_rows` is left
-// unchanged so the runtime falls through to the unclipped reducer count.
-void clip_effective_rows_by_loop_trip_count(std::size_t &effective_rows,
-                                            const StaticAdStackBoundExpr &bound_expr,
-                                            std::size_t dispatched_threads_ceiling,
-                                            Program *prog,
-                                            LaunchContextBuilder *ctx);
-
-// Adstack-cache invalidation bump. Called from each backend's kernel launcher BEFORE the per-task
-// `publish_adstack_metadata` loop runs, so the per-task metadata cache (`Program::*PerTaskAdStackCacheEntry`) snapshots
-// the latest counters at record time and the next lookup detects any drift. Two sources contribute:
-//
-// - SNode writes: every task in the kernel lists its compile-time `snode_writes` set (computed at codegen via
-//   `irpass::analysis::gather_snode_read_writes`), bumped per id; covers `SizeExpr::FieldLoad` cache invalidation.
-// - ndarray data writes: every arg slot the kernel writes to (`OffloadedTask::arr_writes` on LLVM-GPU, the kernel-
-//   level `ctx_attribs.arr_access` WRITE bits on SPIR-V) bumps the bound `DeviceAllocation`'s data generation.
-//   SPIR-V also bumps on the `kNone` READ branch to catch host-driven mutations of raw numpy / torch buffers blitted
-//   between launches; covers `SizeExpr::ExternalTensorRead` invalidation.
-//
-// The two helpers share the same Program-level effect; their signatures differ only because the codegen-time write
-// sets are stored in different per-backend structs. Forward-only kernels (no adstack tasks) still call these to keep
-// counters monotone, which is cheap (one map insert per snode_id at most).
-void bump_writes_for_kernel_llvm(Program *prog,
-                                 LaunchContextBuilder *ctx,
-                                 const std::vector<OffloadedTask> &offloaded_tasks);
-// CPU launcher overload: per-task snode_writes / arr_writes / arr_reads are stored as separate parallel vectors on
-// the launcher `Context` rather than as `OffloadedTask` clones, for legacy reasons documented in the CPU `Context`
-// struct.
-void bump_writes_for_kernel_llvm(Program *prog,
-                                 LaunchContextBuilder *ctx,
-                                 const std::vector<std::vector<int>> &snode_writes_per_task,
-                                 const std::vector<std::vector<int>> &arr_writes_per_task,
-                                 const std::vector<std::vector<int>> &arr_reads_per_task);
-void bump_writes_for_kernel_spirv(
-    Program *prog,
-    LaunchContextBuilder *ctx,
-    const std::vector<spirv::TaskAttributes> &task_attribs,
-    const std::vector<std::pair<std::vector<int>, irpass::ExternalPtrAccess>> &arr_access);
-
-}  // namespace quadrants::lang
+// Umbrella header for the adstack subsystem. The implementation was split across
+// `quadrants/program/adstack/{cache,eval,max_reducer,device_bytecode,diagnose,write_gen}.{cpp,h}`; this
+// header re-includes the per-stage headers so existing call sites (`#include
+// "quadrants/program/adstack_size_expr_eval.h"`) compile unchanged. New call sites should prefer
+// including the specific stage header they depend on.
+
+#include "quadrants/program/adstack/cache.h"
+#include "quadrants/program/adstack/device_bytecode.h"
+#include "quadrants/program/adstack/diagnose.h"
+#include "quadrants/program/adstack/eval.h"
+#include "quadrants/program/adstack/max_reducer.h"
+#include "quadrants/program/adstack/write_gen.h"
diff --git a/quadrants/python/export_lang.cpp b/quadrants/python/export_lang.cpp
index 818510f5f1..2d352f4473 100644
--- a/quadrants/python/export_lang.cpp
+++ b/quadrants/python/export_lang.cpp
@@ -16,6 +16,7 @@
 #include "quadrants/ir/expression_ops.h"
 #include "quadrants/ir/frontend_ir.h"
 #include "quadrants/ir/statements.h"
+#include "quadrants/program/adstack_size_expr_eval.h"
 #include "quadrants/program/extension.h"
 #include "quadrants/program/ndarray.h"
 #include "quadrants/rhi/device_capability.h"
@@ -413,7 +414,14 @@ void export_lang(py::module &m) {
       .def("get_graph_cache_used_on_last_call", &Program::get_graph_cache_used_on_last_call)
       .def("get_num_offloaded_tasks_on_last_call", &Program::get_num_offloaded_tasks_on_last_call)
       .def("get_graph_num_nodes_on_last_call", &Program::get_graph_num_nodes_on_last_call)
-      .def("get_graph_total_builds", &Program::get_graph_total_builds);
+      .def("get_graph_total_builds", &Program::get_graph_total_builds)
+      // Test-only introspection on the max-reducer dispatch counter. The leading underscore signals "internal, not
+      // part of the public Python API"; quadrants tests reach these via `impl.get_runtime().prog`. They are
+      // intentionally not surfaced on the user-facing `qd.*` namespace and not documented under `docs/`.
+ .def("_get_max_reducer_dispatch_count", + [](Program *program) { return program->adstack_cache().max_reducer_dispatch_count(); }) + .def("_reset_max_reducer_dispatch_count", + [](Program *program) { program->adstack_cache().reset_max_reducer_dispatch_count(); }); py::class_(m, "CompileResult") .def_property_readonly( diff --git a/quadrants/rhi/cuda/cuda_device.cpp b/quadrants/rhi/cuda/cuda_device.cpp index 7baca21095..8972d47b6a 100644 --- a/quadrants/rhi/cuda/cuda_device.cpp +++ b/quadrants/rhi/cuda/cuda_device.cpp @@ -1,6 +1,8 @@ #include "quadrants/rhi/cuda/cuda_device.h" #include "quadrants/rhi/llvm/device_memory_pool.h" +#include + #include "quadrants/jit/jit_module.h" namespace quadrants::lang { @@ -23,7 +25,13 @@ RhiResult CudaDevice::allocate_memory(const AllocParams ¶ms, DeviceAllocatio auto &mem_pool = DeviceMemoryPool::get_instance(Arch::cuda, true /*merge_upon_release*/); bool managed = params.host_read || params.host_write; + fprintf(stderr, "[trace cuda allocate_memory] size=%zu managed=%d host_read=%d host_write=%d export=%d usage=%d\n", + params.size, (int)managed, (int)params.host_read, (int)params.host_write, (int)params.export_sharing, + (int)params.usage); + fflush(stderr); void *ptr = mem_pool.allocate(params.size, DeviceMemoryPool::page_size, managed); + fprintf(stderr, "[trace cuda allocate_memory] pool.allocate returned ptr=%p\n", ptr); + fflush(stderr); if (ptr == nullptr) { return RhiResult::out_of_memory; } @@ -34,7 +42,11 @@ RhiResult CudaDevice::allocate_memory(const AllocParams ¶ms, DeviceAllocatio info.use_cached = false; info.use_preallocated = false; + fprintf(stderr, "[trace cuda allocate_memory] memset (cuMemsetD8) ptr=%p size=%zu\n", info.ptr, info.size); + fflush(stderr); CUDADriver::get_instance().memset((void *)info.ptr, 0, info.size); + fprintf(stderr, "[trace cuda allocate_memory] memset done\n"); + fflush(stderr); *out_devalloc = DeviceAllocation{}; out_devalloc->alloc_id = allocations_.size(); @@ -47,13 +59,20 @@ RhiResult CudaDevice::allocate_memory(const AllocParams ¶ms, DeviceAllocatio DeviceAllocation CudaDevice::allocate_memory_runtime(const LlvmRuntimeAllocParams ¶ms) { AllocInfo info; info.size = quadrants::iroundup(params.size, quadrants_page_size); + fprintf(stderr, "[trace cuda allocate_memory_runtime] requested_size=%zu rounded_size=%zu use_memory_pool=%d\n", + params.size, info.size, (int)params.use_memory_pool); + fflush(stderr); if (info.size == 0) { info.ptr = nullptr; } else if (params.use_memory_pool) { CUDADriver::get_instance().malloc_async((void **)&info.ptr, info.size, nullptr); + fprintf(stderr, "[trace cuda allocate_memory_runtime] malloc_async ptr=%p\n", info.ptr); + fflush(stderr); } else { info.ptr = DeviceMemoryPool::get_instance(Arch::cuda, true /*merge_upon_release*/).allocate_with_cache(this, params); + fprintf(stderr, "[trace cuda allocate_memory_runtime] allocate_with_cache ptr=%p\n", info.ptr); + fflush(stderr); if (!info.ptr) { DeviceAllocation fail_alloc; @@ -64,8 +83,13 @@ DeviceAllocation CudaDevice::allocate_memory_runtime(const LlvmRuntimeAllocParam } } - if (info.ptr) + if (info.ptr) { + fprintf(stderr, "[trace cuda allocate_memory_runtime] memset ptr=%p size=%zu\n", info.ptr, info.size); + fflush(stderr); CUDADriver::get_instance().memset((void *)info.ptr, 0, info.size); + fprintf(stderr, "[trace cuda allocate_memory_runtime] memset done\n"); + fflush(stderr); + } info.is_imported = false; info.use_cached = true; diff --git a/quadrants/runtime/amdgpu/kernel_launcher.cpp 
b/quadrants/runtime/amdgpu/kernel_launcher.cpp index fdb633164b..b32ee7a066 100644 --- a/quadrants/runtime/amdgpu/kernel_launcher.cpp +++ b/quadrants/runtime/amdgpu/kernel_launcher.cpp @@ -69,6 +69,9 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, // the cleared counter and UINT32_MAX-defaulted capacity arrays. executor->publish_adstack_lazy_claim_buffers(offloaded_tasks.size()); } + // Max-reducer dispatch. Mirrors the CUDA launcher; results land in `current_max_reducer_results_` for + // `publish_adstack_metadata` to substitute. + executor->dispatch_max_reducers_for_tasks(offloaded_tasks, &ctx, context_pointer); std::size_t task_index = 0; for (const auto &task : offloaded_tasks) { int effective_grid_dim = task.grid_dim; diff --git a/quadrants/runtime/cpu/kernel_launcher.cpp b/quadrants/runtime/cpu/kernel_launcher.cpp index d2c3a97900..d077820055 100644 --- a/quadrants/runtime/cpu/kernel_launcher.cpp +++ b/quadrants/runtime/cpu/kernel_launcher.cpp @@ -29,6 +29,11 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, } // Span every task's `publish_adstack_metadata` call below with one shared read cache. SizeExprLaunchScope launch_scope; + // Max-reducer dispatch. Runs before the per-task `publish_adstack_metadata` loop so the dispatched values are + // available to substitute into per-stack `SerializedSizeExpr` trees inside each per-task encoder call. Stashed on the + // executor (`current_max_reducer_results_`); `publish_adstack_metadata` reads it. Empty map on kernels with no + // captured specs - no per-launch overhead in that case. + executor->dispatch_max_reducers_for_tasks(ad_stacks, &ctx, /*device_runtime_context_ptr=*/nullptr); for (size_t i = 0; i < task_funcs.size(); ++i) { if (!ad_stacks[i].allocas.empty()) { executor->publish_adstack_metadata(ad_stacks[i], num_threads_per_task[i], &ctx); diff --git a/quadrants/runtime/cuda/kernel_launcher.cpp b/quadrants/runtime/cuda/kernel_launcher.cpp index b695e8180b..e449f4a861 100644 --- a/quadrants/runtime/cuda/kernel_launcher.cpp +++ b/quadrants/runtime/cuda/kernel_launcher.cpp @@ -81,6 +81,10 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, // the cleared counter and UINT32_MAX-defaulted capacity arrays. executor->publish_adstack_lazy_claim_buffers(offloaded_tasks.size()); } + // Max-reducer dispatch. Runs before the per-task loop so each `publish_adstack_metadata` call sees the result map via + // the executor's `current_max_reducer_results_` and can substitute captured `MaxOverRange`s inside its encoder. Empty + // map (and zero per-launch overhead) when no task has captured specs. 
+  executor->dispatch_max_reducers_for_tasks(offloaded_tasks, &ctx, device_context_ptr);
   std::size_t task_index = 0;
   for (const auto &task : offloaded_tasks) {
     int effective_grid_dim = task.grid_dim;
diff --git a/quadrants/runtime/gfx/CMakeLists.txt b/quadrants/runtime/gfx/CMakeLists.txt
index dfec98a6e7..e026d8d428 100644
--- a/quadrants/runtime/gfx/CMakeLists.txt
+++ b/quadrants/runtime/gfx/CMakeLists.txt
@@ -5,6 +5,7 @@ target_sources(gfx_runtime
   PRIVATE
     runtime.cpp
     adstack_bound_reducer_launch.cpp
+    adstack_max_reducer_launch.cpp
     adstack_sizer_launch.cpp
     snode_tree_manager.cpp
     kernel_launcher.cpp
diff --git a/quadrants/runtime/gfx/adstack_max_reducer_launch.cpp b/quadrants/runtime/gfx/adstack_max_reducer_launch.cpp
new file mode 100644
index 0000000000..cf6f9cedf8
--- /dev/null
+++ b/quadrants/runtime/gfx/adstack_max_reducer_launch.cpp
@@ -0,0 +1,529 @@
+// Max-reducer dispatch for SPIR-V backends. Extracted out of `runtime.cpp` for the same reason
+// `adstack_bound_reducer_launch.cpp` is - keeps `GfxRuntime::launch_kernel` focused on the main-kernel record/submit
+// flow. Conditional on at least one task in the kernel having a non-empty
+// `TaskAttributes::AdStackSizingAttribs::max_reducer_specs`. Returns an empty map on devices missing PSB+Int64 caps or
+// on kernels with no captured specs; the caller falls through to the per-thread sizer eval, whose `1<<24` cap then
+// surfaces as a hard error via the device sizer's overflow-flag slot if the iteration count exceeds the cap.
+//
+// Per-spec mechanism:
+// 1. Pack the cache key `(registry_id, stack_id, mor_node_idx)` and query `AdStackCache::try_max_reducer_cache_hit`.
+//    On hit, record the cached value in the result map and skip the dispatch.
+// 2. On miss, host-evaluate the captured `begin` and `end` subtrees via `evaluate_adstack_size_expr_at_node` (the
+//    recognizer grammar guarantees both subtrees are closed-form). On resolution failure, skip the spec with a -1
+//    length.
+// 3. Encode the body subtree into the shared bytecode buffer via `encode_max_reducer_body_bytecode`. The encoder
+//    extracts reachable nodes in post-order, renumbers to dense `[0, body_node_count)` indices, copies the referenced
+//    indices-buffer entries, and resolves each `kExternalTensorRead` leaf's `arg_buffer_offset` via the closure
+//    passed here.
+// 4. Build the `AdStackMaxReducerParams` blob into the shared params buffer at a descriptor-aligned offset.
+// 5. Build a single cmdlist with one dispatch per missed spec (each binds the same args/output buffers but a per-spec
+//    slice of params + bytecode), then submit_synced.
+// 6. Map the output buffer, read each missed spec's i64 slot into the result map, and call
+//    `AdStackCache::record_max_reducer_eval` with the body's read observations + the dispatched value so the next
+//    launch can short-circuit on a generation match.
+//
+// Caller responsibility: invoke `dispatch_max_reducers` BEFORE `publish_adstack_metadata_spirv` and pass the returned
+// map down so the per-task sizer / device sizer encoder can substitute results into per-stack `SerializedSizeExpr`
+// trees via `substitute_precomputed_max_over_range`.
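For reference, the value a single max-reducer dispatch produces is just a range-max of the body expression. Below is a minimal host-side sketch of those semantics - `body` is a stand-in callable for the bytecode-interpreted body, and none of these names exist in the patch:

```cpp
#include <algorithm>
#include <cstdint>
#include <functional>

// Sequential reference for what one dispatch computes over [begin, end).
// The real shader gives each thread a 64-element strided slice of the range
// and combines per-thread partial maxima with an atomic signed max.
int64_t max_reducer_reference(int64_t begin, int64_t end,
                              const std::function<int64_t(int64_t)> &body) {
  int64_t best = 0;  // an empty range records 0, matching the dispatch path
  for (int64_t i = begin; i < end; ++i) {
    best = std::max(best, body(i));
  }
  return best;
}
```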
+
+#include "quadrants/runtime/gfx/runtime.h"
+
+#include <algorithm>
+#include <cstring>
+#include <limits>
+#include <vector>
+
+#include "quadrants/codegen/spirv/adstack_max_reducer_shader.h"
+#include "quadrants/common/logging.h"
+#include "quadrants/ir/adstack_size_expr_device.h"
+#include "quadrants/ir/snode.h"
+#include "quadrants/ir/type_factory.h"
+#include "quadrants/program/adstack/device_bytecode.h"
+#include "quadrants/program/launch_context_builder.h"
+#include "quadrants/program/program.h"
+#include "quadrants/rhi/device.h"
+
+namespace quadrants::lang {
+namespace gfx {
+
+namespace {
+
+// Resolve the byte offset within the kernel arg buffer where an ndarray argument's `data_ptr` (u64) lives. Mirrors
+// `adstack_bound_reducer_launch.cpp::resolve_ndarray_data_ptr_byte_offset`; centralised in a single helper per
+// launcher TU to keep the layout knowledge pinned to one call site per backend.
+size_t resolve_ndarray_data_ptr_byte_offset(LaunchContextBuilder &host_ctx, const std::vector<int> &arg_id_path) {
+  QD_ASSERT_INFO(host_ctx.args_type != nullptr,
+                 "adstack max reducer: LaunchContextBuilder::args_type is null; cannot resolve ndarray data "
+                 "pointer offset for the captured spec");
+  std::vector<int> indices = arg_id_path;
+  indices.push_back(TypeFactory::DATA_PTR_POS_IN_NDARRAY);
+  return host_ctx.args_type->get_element_offset(indices);
+}
+
+// Per-spec dispatch unit, populated from each captured `StaticAdStackMaxReducerSpec`. Pass 1 (`collect_specs`) only
+// fills the cache-key / identity fields and the back-references to the source `SerializedSizeExpr` and spec; the
+// substitution-aware `prepare_spec` step writes the host-eval-derived `length` / `per_axis_*` and the body bytecode
+// once the spec's `dependent_mor_node_idxs` are all in `result`. Specs whose preparation fails (axis resolution
+// failure, body grammar reject, body too large) flip `dropped` and are excluded from dispatch.
+struct PendingMaxReducerDispatch {
+  uint64_t cache_key;
+  uint32_t registry_id;
+  int32_t stack_id;
+  int32_t mor_node_idx;
+  const SerializedSizeExpr *expr;
+  const StaticAdStackMaxReducerSpec *spec;
+  bool dispatched{false};
+  bool dropped{false};
+  uint32_t length{0};
+  uint32_t num_axes{0};
+  std::vector<uint32_t> per_axis_length;
+  std::vector<int64_t> per_axis_begin;
+  std::vector<char> body_bytecode;
+  uint32_t body_node_count{0};
+  uint32_t indices_count{0};
+  std::vector reads;
+};
+
+// Pack `(registry_id, stack_id, mor_node_idx)` into the same 64-bit key encoding `AdStackCache::pack_max_reducer_key`
+// uses internally (low 32 bits = registry_id, mid 16 = stack_id, high 16 = mor_node_idx). Mirrored here rather than
+// exposed as a public helper because the caller's need is limited to this TU.
+uint64_t pack_max_reducer_key(uint32_t registry_id, int32_t stack_id, int32_t mor_node_idx) {
+  return (static_cast<uint64_t>(registry_id) & 0xFFFFFFFFull) |
+         ((static_cast<uint64_t>(stack_id) & 0xFFFFull) << 32) |
+         ((static_cast<uint64_t>(mor_node_idx) & 0xFFFFull) << 48);
+}
+
+}  // namespace
+
+MaxReducerResultMap GfxRuntime::dispatch_max_reducers(LaunchContextBuilder &host_ctx,
+                                                      DeviceAllocationGuard *args_buffer,
+                                                      const std::unordered_map<int, DeviceAllocation> &ndarray_allocs,
+                                                      const std::vector<spirv::TaskAttributes> &task_attribs) {
+  MaxReducerResultMap result;
+
+  // The shader builder requires `spirv_has_physical_storage_buffer` (PSB body-leaf reads through the kernel arg
+  // buffer's data pointers) and `spirv_has_int64` (i64 arithmetic inside the body interpreter, plus i64 begin
+  // reassembly).
+  // On a device missing either cap, `build_adstack_max_reducer_spirv` returns an empty binary and the lazy
+  // pipeline init below would assert. Skip the dispatch entirely and return an empty result map; the caller's
+  // substitution helper then leaves every captured `MaxOverRange` in place, so the per-task sizer falls back to its
+  // existing capped host-eval path. This call runs before `publish_adstack_metadata_spirv`'s own cap gate, so that
+  // gate cannot protect this entry point; we recheck here independently.
+  if (!device_->get_caps().get(DeviceCapability::spirv_has_physical_storage_buffer) ||
+      !device_->get_caps().get(DeviceCapability::spirv_has_int64)) {
+    return result;
+  }
+
+  Program *prog = (program_impl_ != nullptr) ? program_impl_->program : nullptr;
+  AdStackCache *cache = (prog != nullptr) ? &prog->adstack_cache() : nullptr;
+
+  // Pass 1: collect specs into pending. Cache hits go straight to `result`; misses go to pending with back-references
+  // to the source `SerializedSizeExpr` and `StaticAdStackMaxReducerSpec`. Host-evaluation of begin / end and body
+  // bytecode encoding is deferred to the per-level prepare step below, where each spec's
+  // `dependent_mor_node_idxs` have already been substituted into the working tree.
+  std::vector<PendingMaxReducerDispatch> pending;
+  pending.reserve(task_attribs.size());
+  for (size_t ti = 0; ti < task_attribs.size(); ++ti) {
+    const auto &attribs = task_attribs[ti];
+    if (attribs.ad_stack.max_reducer_specs.empty()) {
+      continue;
+    }
+    // Lazily register the task with the Program-side identity registry; `publish_adstack_metadata_spirv` is
+    // idempotent for already-registered tasks. The cache-key encoding uses `registry_id` to disambiguate same-shape
+    // MORs across kernels, so the id has to exist before the first cache lookup.
+    auto &mutable_attribs =
+        const_cast<spirv::TaskAttributes::AdStackSizingAttribs &>(attribs.ad_stack);
+    if (mutable_attribs.registry_id == 0 && cache != nullptr) {
+      std::vector<uint64_t> allocated_max_sizes;
+      std::vector<SerializedSizeExpr> size_exprs;
+      allocated_max_sizes.reserve(mutable_attribs.allocas.size());
+      size_exprs.reserve(mutable_attribs.allocas.size());
+      for (const auto &a : mutable_attribs.allocas) {
+        allocated_max_sizes.push_back(static_cast<uint64_t>(a.max_size_compile_time));
+        size_exprs.push_back(a.size_expr);
+      }
+      mutable_attribs.registry_id = cache->register_adstack_sizing_info(
+          static_cast<const void *>(&mutable_attribs), /*kernel_name=*/std::string{}, static_cast<int>(ti),
+          std::move(allocated_max_sizes), std::move(size_exprs));
+    }
+    const uint32_t registry_id = mutable_attribs.registry_id;
+    if (registry_id == 0) {
+      continue;
+    }
+    for (const auto &spec : attribs.ad_stack.max_reducer_specs) {
+      const uint64_t key = pack_max_reducer_key(registry_id, spec.stack_id, spec.mor_node_idx);
+      if (cache != nullptr) {
+        int64_t cached;
+        if (cache->try_max_reducer_cache_hit(registry_id, spec.stack_id, spec.mor_node_idx, &host_ctx, cached)) {
+          result[key] = cached;
+          continue;
+        }
+      }
+      PendingMaxReducerDispatch p{};
+      p.cache_key = key;
+      p.registry_id = registry_id;
+      p.stack_id = spec.stack_id;
+      p.mor_node_idx = spec.mor_node_idx;
+      p.expr = &attribs.ad_stack.allocas[spec.stack_id].size_expr;
+      p.spec = &spec;
+      pending.push_back(std::move(p));
+    }
+  }
+
+  if (pending.empty()) {
+    return result;
+  }
+
+  // Lazy-init pipeline. Mirror `adstack_bound_reducer_launch.cpp`'s pattern: build the SPIR-V binary once via the
+  // shader-build helper, hand to the device's pipeline factory, cache for the runtime's lifetime.
+  if (!adstack_max_reducer_pipeline_) {
+    std::vector<uint32_t> spirv = spirv::build_adstack_max_reducer_spirv(Arch::vulkan, &device_->get_caps());
+    QD_ASSERT_INFO(!spirv.empty(),
+                   "build_adstack_max_reducer_spirv returned an empty binary despite the PSB+Int64 cap "
+                   "check passing; bug in the shader builder's capability gating");
+    PipelineSourceDesc source_desc{PipelineSourceType::spirv_binary, (void *)spirv.data(),
+                                   spirv.size() * sizeof(uint32_t)};
+    auto [pipeline, res] = device_->create_pipeline_unique(source_desc, "adstack_max_reducer", backend_cache_.get());
+    QD_ERROR_IF(res != RhiResult::success, "Failed to create pipeline for the adstack max reducer (err: {})",
+                int(res));
+    adstack_max_reducer_pipeline_ = std::move(pipeline);
+  }
+
+  // Slot-0 placeholder for kernels with no kernel arg buffer. Same RHI rule as the bound reducer: descriptor-set
+  // layouts require a non-null binding even if the shader's branch never reads it.
+  if (args_buffer == nullptr && !adstack_max_reducer_args_placeholder_buffer_) {
+    auto [buf, res] = device_->allocate_memory_unique({sizeof(uint32_t), /*host_write=*/false, /*host_read=*/false,
+                                                       /*export_sharing=*/false, AllocUsage::Storage});
+    QD_ASSERT_INFO(res == RhiResult::success, "Failed to allocate adstack max reducer slot-0 placeholder buffer");
+    adstack_max_reducer_args_placeholder_buffer_ = std::move(buf);
+  }
+
+  constexpr size_t kDescriptorOffsetAlignment = 256;
+  auto align_up = [](size_t v, size_t a) { return (v + a - 1) & ~(a - 1); };
+  const size_t params_size_bytes = spirv::AdStackMaxReducerParams::kNumWords * sizeof(uint32_t);
+  auto grow_buffer = [&](std::unique_ptr<DeviceAllocationGuard> &buf, size_t &capacity, size_t needed, bool host_write,
+                         bool host_read, const char *label) {
+    if (buf && capacity >= needed) {
+      return;
+    }
+    size_t new_size = std::max(needed, 2 * capacity);
+    auto [new_buf, res] = device_->allocate_memory_unique(
+        {new_size, host_write, host_read, /*export_sharing=*/false, AllocUsage::Storage});
+    QD_ASSERT_INFO(res == RhiResult::success, "Failed to allocate {} (size={})", label, new_size);
+    if (buf) {
+      ctx_buffers_.push_back(std::move(buf));
+    }
+    buf = std::move(new_buf);
+    capacity = new_size;
+  };
+
+  // Level-based dispatch: each iteration picks every undispatched spec whose `dependent_mor_node_idxs` are all
+  // already in `result` (cache hits + earlier rounds), substitutes those values into the working tree, host-evaluates
+  // begin / end against the substituted tree, encodes the body bytecode, then dispatches the level's specs as a
+  // single batched cmdlist. Most kernels have specs without inter-spec dependencies and finish in one round; nested
+  // patterns (e.g. an outer `MaxOverRange` whose end contains a previously-captured inner `max-of-array`) take one
+  // round per dependency depth. A round that picks no specs but has unprocessed pending entries breaks out via the
+  // `cycle / unresolvable` guard and leaves those entries dropped, falling through to the per-task device sizer. The
+  // loop below implements this; the sketch that follows shows the scheduling skeleton in isolation.
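The round-based scheduling described above, reduced to its skeleton. `SpecState` and the two callbacks are hypothetical stand-ins for `PendingMaxReducerDispatch` and the cache/dispatch machinery; the real loop additionally host-evaluates begin/end and encodes body bytecode per round:

```cpp
#include <cstddef>
#include <functional>
#include <vector>

struct SpecState {
  std::vector<int> deps;  // mor_node_idxs this spec's begin/end subtrees read
  bool done = false;
};

// Repeatedly dispatch every spec whose dependencies are already resolved;
// stop when a full pass makes no progress (cycle / unresolvable specs drop).
void run_rounds(std::vector<SpecState> &specs,
                const std::function<bool(int)> &dep_resolved,
                const std::function<void(std::size_t)> &dispatch_one) {
  std::size_t processed = 0;
  while (processed < specs.size()) {
    std::size_t progressed = 0;
    for (std::size_t k = 0; k < specs.size(); ++k) {
      if (specs[k].done) continue;
      bool ready = true;
      for (int d : specs[k].deps) ready = ready && dep_resolved(d);
      if (!ready) continue;
      dispatch_one(k);  // one batched cmdlist per round in the real code
      specs[k].done = true;
      ++progressed;
      ++processed;
    }
    if (progressed == 0) break;  // no progress possible: remaining specs drop
  }
}
```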
+  size_t dispatched_count = 0;
+  size_t dropped_count = 0;
+  while (dispatched_count + dropped_count < pending.size()) {
+    std::vector<size_t> level_indices;
+    for (size_t k = 0; k < pending.size(); ++k) {
+      if (pending[k].dispatched || pending[k].dropped)
+        continue;
+      bool deps_ok = true;
+      for (int32_t dep_node : pending[k].spec->dependent_mor_node_idxs) {
+        const uint64_t dep_key = pack_max_reducer_key(pending[k].registry_id, pending[k].stack_id, dep_node);
+        if (result.find(dep_key) == result.end()) {
+          deps_ok = false;
+          break;
+        }
+      }
+      if (deps_ok)
+        level_indices.push_back(k);
+    }
+    if (level_indices.empty()) {
+      // Cycle / unresolvable - no progress possible. Drop remaining and let the per-task sizer absorb them.
+      for (size_t k = 0; k < pending.size(); ++k) {
+        if (!pending[k].dispatched && !pending[k].dropped) {
+          pending[k].dropped = true;
+          ++dropped_count;
+        }
+      }
+      break;
+    }
+
+    // Prepare each ready spec: substitute already-resolved deps' values into the tree, host-eval begin / end, encode
+    // body bytecode. Specs whose preparation fails (axis non-resolvable, length over the u32 cap, body grammar
+    // reject) mark `dropped` and are skipped for this round and forever.
+    auto arg_buffer_offset_resolver = [&](const std::vector<int> &arg_id_path) -> int32_t {
+      std::vector<int> path(arg_id_path.begin(), arg_id_path.end());
+      const size_t byte_off = resolve_ndarray_data_ptr_byte_offset(host_ctx, path);
+      if (byte_off > std::numeric_limits<int32_t>::max()) {
+        return -1;
+      }
+      return static_cast<int32_t>(byte_off);
+    };
+    // SPIR-V FieldLoad-with-bound-var-index emitter: resolve `(snode tree root_psb + place_byte_offset_in_root)` plus
+    // per-active-axis element strides for each `kFieldLoad` body leaf. Mirrors the per-task sizer's emitter in
+    // `device_bytecode.cpp::encode_adstack_size_expr_device_bytecode_for_spirv`. The encoder folds the
+    // closed-FieldLoad path host-side (no emitter call) and routes only bound-var-indexed leaves through this
+    // closure.
+    Device *dev = device_;
+    FieldLoadDeviceEmitter fl_emitter{};
+    fl_emitter.fetch = [prog, dev](SNode *snode, uint64_t *out_base_psb,
+                                   std::vector<uint32_t> *out_elem_strides) -> bool {
+      if (snode == nullptr || prog == nullptr || dev == nullptr) {
+        return false;
+      }
+      if (!compute_dense_snode_strides(snode, out_elem_strides)) {
+        return false;
+      }
+      const int tree_id = snode->get_snode_tree_id();
+      DevicePtr tree_root_devptr = prog->get_snode_tree_device_ptr(tree_id);
+      const uint64_t root_psb = dev->get_memory_physical_pointer(tree_root_devptr);
+      if (root_psb == 0) {
+        return false;
+      }
+      const size_t place_byte_offset = prog->get_field_in_tree_offset(tree_id, snode);
+      *out_base_psb = root_psb + static_cast<uint64_t>(place_byte_offset);
+      return true;
+    };
+    std::vector<size_t> level_dispatch;
+    level_dispatch.reserve(level_indices.size());
+    for (size_t k : level_indices) {
+      const auto *spec = pending[k].spec;
+      const std::size_t num_axes = spec->axis_var_ids.size();
+      if (num_axes == 0 || num_axes > static_cast<std::size_t>(kAdStackMaxReducerMaxAxes)) {
+        pending[k].dropped = true;
+        ++dropped_count;
+        continue;
+      }
+      // Substitute every already-resolved MOR in `result` (for this spec's stack) into a working copy of the tree, so
+      // begin / end host-evaluation sees the dependent specs as `kConst` instead of walking through them. A sketch of
+      // this substitution step follows below.
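A hedged sketch of the substitution step just referenced. The node layout here is hypothetical - the real `SerializedSizeExpr` is richer - but the shape of the transform (resolved `MaxOverRange` nodes collapse to constants, unresolved ones stay put) is as described:

```cpp
#include <cstdint>
#include <unordered_map>
#include <vector>

enum class NodeKind { Const, MaxOverRange /*, ... other expression nodes */ };

struct ExprNode {
  NodeKind kind;
  int64_t const_value;   // payload when kind == Const
  int32_t mor_node_idx;  // identity when kind == MaxOverRange
};

// Collapse every MaxOverRange whose packed key already has a dispatched value.
void substitute_sketch(std::vector<ExprNode> &nodes, uint32_t registry_id, int32_t stack_id,
                       const std::unordered_map<uint64_t, int64_t> &results,
                       uint64_t (*pack_key)(uint32_t, int32_t, int32_t)) {
  for (auto &n : nodes) {
    if (n.kind != NodeKind::MaxOverRange) continue;
    auto it = results.find(pack_key(registry_id, stack_id, n.mor_node_idx));
    if (it == results.end()) continue;  // not dispatched yet: leave in place
    n.kind = NodeKind::Const;
    n.const_value = it->second;
  }
}
```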
+      const SerializedSizeExpr substituted =
+          substitute_precomputed_max_over_range(*pending[k].expr, pending[k].registry_id, pending[k].stack_id, result);
+      std::vector<uint32_t> per_axis_length_v(num_axes, 0);
+      std::vector<int64_t> per_axis_begin_v(num_axes, 0);
+      bool axes_ok = true;
+      uint64_t total_length = 1;
+      for (std::size_t a = 0; a < num_axes; ++a) {
+        const int64_t bv =
+            evaluate_adstack_size_expr_at_node(substituted, spec->axis_begin_node_idxs[a], prog, &host_ctx);
+        const int64_t ev =
+            evaluate_adstack_size_expr_at_node(substituted, spec->axis_end_node_idxs[a], prog, &host_ctx);
+        if (bv < 0 || ev < 0 || ev <= bv) {
+          axes_ok = false;
+          break;
+        }
+        const int64_t len = ev - bv;
+        if (len > std::numeric_limits<uint32_t>::max()) {
+          axes_ok = false;
+          break;
+        }
+        per_axis_begin_v[a] = bv;
+        per_axis_length_v[a] = static_cast<uint32_t>(len);
+        total_length *= static_cast<uint64_t>(len);
+        if (total_length > std::numeric_limits<uint32_t>::max()) {
+          axes_ok = false;
+          break;
+        }
+      }
+      if (!axes_ok) {
+        pending[k].dropped = true;
+        ++dropped_count;
+        continue;
+      }
+      EncodedMaxReducerBody encoded =
+          encode_max_reducer_body_bytecode(substituted, spec->body_node_idx, spec->axis_var_ids,
+                                           arg_buffer_offset_resolver, &host_ctx, prog, &fl_emitter);
+      if (encoded.body_node_count == 0 || encoded.body_node_count > spirv::kAdStackMaxReducerMaxBodyNodes) {
+        pending[k].dropped = true;
+        ++dropped_count;
+        continue;
+      }
+      pending[k].length = static_cast<uint32_t>(total_length);
+      pending[k].num_axes = static_cast<uint32_t>(num_axes);
+      pending[k].per_axis_length = std::move(per_axis_length_v);
+      pending[k].per_axis_begin = std::move(per_axis_begin_v);
+      pending[k].body_bytecode = std::move(encoded.bytes);
+      pending[k].body_node_count = encoded.body_node_count;
+      pending[k].indices_count = encoded.indices_count;
+      pending[k].reads = std::move(encoded.body_reads);
+      level_dispatch.push_back(k);
+    }
+    if (level_dispatch.empty()) {
+      continue;  // every ready spec failed preparation; loop checks for more progress next iteration
+    }
+
+    // Pack params + bytecode for this level. Output buffer holds two u32 slots per dispatched spec (`[value,
+    // overflow_flag]`); the spec's slot index in this round's output buffer is its position in `level_dispatch`.
+    std::vector<size_t> per_spec_params_offsets(level_dispatch.size());
+    std::vector<uint32_t> per_spec_bytecode_word_offsets(level_dispatch.size());
+    size_t total_params_bytes = 0;
+    size_t total_bytecode_bytes = 0;
+    for (size_t i = 0; i < level_dispatch.size(); ++i) {
+      const size_t k = level_dispatch[i];
+      per_spec_params_offsets[i] = align_up(total_params_bytes, kDescriptorOffsetAlignment);
+      total_params_bytes = per_spec_params_offsets[i] + params_size_bytes;
+      QD_ASSERT_INFO(pending[k].body_bytecode.size() % sizeof(uint32_t) == 0,
+                     "max-reducer body bytecode is not 4-byte aligned (size={})", pending[k].body_bytecode.size());
+      per_spec_bytecode_word_offsets[i] = static_cast<uint32_t>(total_bytecode_bytes / sizeof(uint32_t));
+      total_bytecode_bytes += pending[k].body_bytecode.size();
+    }
+    const size_t output_bytes = level_dispatch.size() * 2 * sizeof(uint32_t);
+
+    grow_buffer(adstack_max_reducer_params_buffer_, adstack_max_reducer_params_buffer_size_, total_params_bytes,
+                /*host_write=*/true, /*host_read=*/false, "adstack max reducer params buffer");
+    grow_buffer(adstack_max_reducer_bytecode_buffer_, adstack_max_reducer_bytecode_buffer_size_, total_bytecode_bytes,
+                /*host_write=*/true, /*host_read=*/false, "adstack max reducer bytecode buffer");
+    grow_buffer(adstack_max_reducer_output_buffer_, adstack_max_reducer_output_buffer_size_, output_bytes,
+                /*host_write=*/false, /*host_read=*/true, "adstack max reducer output buffer");
+
+    // Write params + bytecode into their host-mapped buffers.
+    {
+      void *mapped = nullptr;
+      RhiResult map_res =
+          device_->map_range(adstack_max_reducer_params_buffer_->get_ptr(0), total_params_bytes, &mapped);
+      QD_ASSERT_INFO(map_res == RhiResult::success, "Failed to map adstack max reducer params buffer");
+      for (size_t i = 0; i < level_dispatch.size(); ++i) {
+        const size_t k = level_dispatch[i];
+        spirv::AdStackMaxReducerParams params{};
+        params.output_slot = static_cast<uint32_t>(i);
+        params.length = pending[k].length;
+        params.num_axes = pending[k].num_axes;
+        params.body_bytecode_offset_words = per_spec_bytecode_word_offsets[i];
+        params.body_node_count = pending[k].body_node_count;
+        const uint32_t node_words = sizeof(AdStackSizeExprDeviceNode) / 4u;
+        params.body_indices_offset_words = per_spec_bytecode_word_offsets[i] + pending[k].body_node_count * node_words;
+        for (uint32_t a = 0; a < pending[k].num_axes; ++a) {
+          params.per_axis_length[a] = pending[k].per_axis_length[a];
+          const uint64_t begin_u64 = static_cast<uint64_t>(pending[k].per_axis_begin[a]);
+          params.per_axis_begin_lo[a] = static_cast<uint32_t>(begin_u64 & 0xFFFFFFFFull);
+          params.per_axis_begin_hi[a] = static_cast<uint32_t>((begin_u64 >> 32) & 0xFFFFFFFFull);
+          params.per_axis_var_id[a] = static_cast<uint32_t>(a);
+        }
+        std::memcpy(reinterpret_cast<char *>(mapped) + per_spec_params_offsets[i], &params, params_size_bytes);
+      }
+      device_->unmap(*adstack_max_reducer_params_buffer_);
+    }
+    if (total_bytecode_bytes > 0) {
+      void *mapped = nullptr;
+      RhiResult map_res =
+          device_->map_range(adstack_max_reducer_bytecode_buffer_->get_ptr(0), total_bytecode_bytes, &mapped);
+      QD_ASSERT_INFO(map_res == RhiResult::success, "Failed to map adstack max reducer bytecode buffer");
+      char *base = reinterpret_cast<char *>(mapped);
+      size_t cursor = 0;
+      for (size_t i = 0; i < level_dispatch.size(); ++i) {
+        const size_t k = level_dispatch[i];
+        std::memcpy(base + cursor, pending[k].body_bytecode.data(), pending[k].body_bytecode.size());
+        cursor += pending[k].body_bytecode.size();
+      }
+      device_->unmap(*adstack_max_reducer_bytecode_buffer_);
+    }
+
+    flush();
+    device_->wait_idle();
+
+    //
GPU-side clear of the output buffer. Apple Silicon Metal leaves a host-side `map_range` + memset clear sitting in + // a write-combined cache that the next compute pipeline read does not observe; a `buffer_fill` is sequenced by the + // compute queue. + auto [clear_cmdlist, clear_cmdlist_res] = device_->get_compute_stream()->new_command_list_unique(); + QD_ASSERT_INFO(clear_cmdlist_res == RhiResult::success, "Failed to create adstack max reducer clear cmdlist"); + clear_cmdlist->buffer_fill(adstack_max_reducer_output_buffer_->get_ptr(0), output_bytes, /*data=*/0); + clear_cmdlist->buffer_barrier(*adstack_max_reducer_output_buffer_); + device_->get_compute_stream()->submit_synced(clear_cmdlist.get()); + + auto [cmdlist, cmdlist_res] = device_->get_compute_stream()->new_command_list_unique(); + QD_ASSERT_INFO(cmdlist_res == RhiResult::success, "Failed to create adstack max reducer cmdlist"); + // Mirror `adstack_sizer_launch.cpp`'s residency hint so Metal's PSB load path sees ndarray data buffers as + // resident; without `track_physical_buffer` the Apple GPU returns zero / lower-32-bits-of-pointer garbage for every + // `kExternalTensorRead` body load. The same hint covers `kFieldLoad` body leaves: the SNode tree root buffers used + // by the FieldLoad PSB read path are also referenced via raw `bufferDeviceAddress` and need an explicit + // `useResource:` hint on Apple Silicon. Called once per cmdlist (before the per-spec dispatches). + if (device_->get_caps().get(DeviceCapability::spirv_has_physical_storage_buffer)) { + for (const auto &[arg_id, alloc] : ndarray_allocs) { + cmdlist->track_physical_buffer(alloc); + } + for (const auto &root_buffer : root_buffers_) { + if (root_buffer != nullptr) { + cmdlist->track_physical_buffer(*root_buffer); + } + } + } + for (size_t i = 0; i < level_dispatch.size(); ++i) { + const size_t k = level_dispatch[i]; + auto bindings = device_->create_resource_set_unique(); + if (args_buffer != nullptr) { + bindings->rw_buffer(0, *args_buffer); + } else { + bindings->rw_buffer(0, *adstack_max_reducer_args_placeholder_buffer_); + } + bindings->rw_buffer(1, *adstack_max_reducer_output_buffer_); + bindings->rw_buffer(2, adstack_max_reducer_params_buffer_->get_ptr(per_spec_params_offsets[i]), + params_size_bytes); + bindings->rw_buffer(3, *adstack_max_reducer_bytecode_buffer_); + + cmdlist->bind_pipeline(adstack_max_reducer_pipeline_.get()); + RhiResult bind_res = cmdlist->bind_shader_resources(bindings.get()); + QD_ERROR_IF(bind_res != RhiResult::success, "adstack max reducer resource binding error: RhiResult({})", + int(bind_res)); + + // Each thread walks `kElementsPerThread` elements via a strided loop inside the shader; cap workgroup count well + // below the Vulkan / Metal `maxComputeWorkGroupCount[0]` minimum (65535). Keep in sync with the shader. + constexpr uint32_t kElementsPerThread = 64u; + constexpr uint32_t kMaxWorkgroupCountX = 65535u; + const uint32_t threads_per_workgroup = spirv::kAdStackMaxReducerWorkgroupSize; + const uint32_t elements_per_workgroup = threads_per_workgroup * kElementsPerThread; + uint32_t group_x = (pending[k].length + elements_per_workgroup - 1) / elements_per_workgroup; + if (group_x > kMaxWorkgroupCountX) + group_x = kMaxWorkgroupCountX; + if (group_x == 0) { + // Empty range; record 0 directly. RHI rejects 0x1x1 dispatches on most backends. 
+        result[pending[k].cache_key] = 0;
+        pending[k].dispatched = true;
+        ++dispatched_count;
+        continue;
+      }
+      RhiResult dispatch_res = cmdlist->dispatch(group_x, 1, 1);
+      QD_ERROR_IF(dispatch_res != RhiResult::success, "adstack max reducer dispatch error: RhiResult({})",
+                  int(dispatch_res));
+      cmdlist->buffer_barrier(*adstack_max_reducer_output_buffer_);
+    }
+    device_->get_compute_stream()->submit_synced(cmdlist.get());
+
+    // Read back this level's output slots: `slots[2*i]` = u32 max for `level_dispatch[i]`, `slots[2*i + 1]` =
+    // overflow flag. Overflow specs fall back to direct host-eval over the captured MOR node (against the substituted
+    // tree, so already-resolved deps' values are folded in). Cache misses get recorded with their body read
+    // observations so the next launch can short-circuit on a generation match.
+    void *mapped = nullptr;
+    RhiResult map_res = device_->map(*adstack_max_reducer_output_buffer_, &mapped);
+    QD_ASSERT_INFO(map_res == RhiResult::success, "Failed to map adstack max reducer output buffer for readback");
+    const uint32_t *slots = reinterpret_cast<const uint32_t *>(mapped);
+    for (size_t i = 0; i < level_dispatch.size(); ++i) {
+      const size_t k = level_dispatch[i];
+      if (pending[k].dispatched)
+        continue;  // empty-range short-circuit handled above
+      const uint32_t value_u32 = slots[2 * i];
+      const uint32_t overflow_flag = slots[2 * i + 1];
+      int64_t v;
+      if (overflow_flag != 0) {
+        const SerializedSizeExpr substituted = substitute_precomputed_max_over_range(
+            *pending[k].expr, pending[k].registry_id, pending[k].stack_id, result);
+        v = evaluate_adstack_size_expr_at_node(substituted, pending[k].mor_node_idx, prog, &host_ctx);
+        if (v < 0)
+          v = 0;
+      } else {
+        v = static_cast<int64_t>(value_u32);
+      }
+      result[pending[k].cache_key] = v;
+      if (cache != nullptr) {
+        populate_max_reducer_body_observations(pending[k].reads, &host_ctx, cache);
+        cache->record_max_reducer_eval(pending[k].registry_id, pending[k].stack_id, pending[k].mor_node_idx, v,
+                                       std::move(pending[k].reads));
+      }
+      pending[k].dispatched = true;
+      ++dispatched_count;
+    }
+    device_->unmap(*adstack_max_reducer_output_buffer_);
+  }
+
+  return result;
+}
+
+}  // namespace gfx
+}  // namespace quadrants::lang
diff --git a/quadrants/runtime/gfx/adstack_sizer_launch.cpp b/quadrants/runtime/gfx/adstack_sizer_launch.cpp
index ba88c8bca4..1b5f94cdab 100644
--- a/quadrants/runtime/gfx/adstack_sizer_launch.cpp
+++ b/quadrants/runtime/gfx/adstack_sizer_launch.cpp
@@ -103,15 +103,17 @@ void eval_per_task_metadata_on_host(const std::vector<size_t> &adstack_task_indices,
                                     const std::vector<spirv::TaskAttributes> &task_attribs,
                                     Program *prog,
                                     LaunchContextBuilder &host_ctx,
-                                    std::vector &per_task_ad_stack) {
+                                    std::vector &per_task_ad_stack,
+                                    const MaxReducerResultMap &max_reducer_results) {
   using HeapKind = spirv::TaskAttributes::AdStackAllocaAttribs::HeapKind;
   // Span the per-task `evaluate_adstack_size_expr` calls below with one shared read cache.
   SizeExprLaunchScope launch_scope;
   for (size_t ti : adstack_task_indices) {
     const auto &allocas = task_attribs[ti].ad_stack.allocas;
+    const uint32_t registry_id = task_attribs[ti].ad_stack.registry_id;
     auto &rt = per_task_ad_stack[ti];
     const size_t n_stacks = allocas.size();
-    rt.metadata.assign(2 + 2 * n_stacks, 0);
+    rt.metadata.assign(3 + 2 * n_stacks, 0);  // trailing slot is the overflow-flag the on-device sizer writes
     uint32_t running_off_f = 0;
     uint32_t running_off_i = 0;
     for (size_t i = 0; i < n_stacks; ++i) {
@@ -122,7 +124,20 @@
         // Match the shader's `max(max_size_compile_time, 1)` lower clamp.
         max_size = std::max(a.max_size_compile_time, 1u);
       } else {
-        int64_t evaluated = evaluate_adstack_size_expr(a.size_expr, prog, &host_ctx);
+        // Substitute any captured `MaxOverRange` whose result the max-reducer dispatched into a `Const` before the
+        // host evaluator walks the tree. The substituted tree is a stack-local that cannot be used as a stable cache
+        // key, so the substitution branch routes through `evaluate_adstack_size_expr_no_cache`; the empty-results
+        // fast path keeps the live `a.size_expr` reference and the cache stays warm. The non-cache branch's
+        // per-launch eval cost is small (a single tree walk dominated by `ExternalTensorRead` PSB dereferences); the
+        // dispatch that feeds the substitution was the dominant cost in the first place.
+        int64_t evaluated;
+        if (max_reducer_results.empty()) {
+          evaluated = evaluate_adstack_size_expr(a.size_expr, prog, &host_ctx);
+        } else {
+          const SerializedSizeExpr substituted = substitute_precomputed_max_over_range(
+              a.size_expr, registry_id, static_cast<int32_t>(i), max_reducer_results);
+          evaluated = evaluate_adstack_size_expr_no_cache(substituted, prog, &host_ctx);
+        }
         // `evaluate_adstack_size_expr` returns -1 only when `expr.nodes` is empty (handled above) or hits
         // an internal hard error; clamp to the same `max(_, 1)` lower bound the shader applies.
         if (evaluated < 1) {
@@ -156,7 +171,8 @@ std::vector GfxRuntime::publish_adstack_metadata_spirv(
     DeviceAllocationGuard *args_buffer,
     const std::unordered_map<int, DeviceAllocation> &ndarray_allocs,
     const std::vector<spirv::TaskAttributes> &task_attribs,
-    const std::string &kernel_name) {
+    const std::string &kernel_name,
+    const MaxReducerResultMap &max_reducer_results) {
   std::vector per_task_ad_stack(task_attribs.size());
   for (size_t ti = 0; ti < task_attribs.size(); ++ti) {
     per_task_ad_stack[ti].stride_float = task_attribs[ti].ad_stack.per_thread_stride_float_compile_time;
@@ -178,6 +194,16 @@ std::vector GfxRuntime::publish_adstack_metadata_spirv(
                  "encode AdStack SizeExpr bytecode. Ensure GfxProgramImpl passes `program_impl = this` "
                  "into `GfxRuntime::Params`.");

+  // Reverse-mode autodiff with adstacks requires Vulkan 1.3 (or Metal at MTLArgumentBuffersTier::Tier2) on this
+  // device; older drivers cannot run the sizer paths correctly. `dispatch_adstack_bound_reducers` runs after this
+  // gate and relies on this single check; `dispatch_max_reducers` runs before `publish_adstack_metadata_spirv`, so it
+  // rechecks the caps independently and returns an empty map on older devices.
+  QD_ERROR_IF(!device_->get_caps().get(DeviceCapability::spirv_has_physical_storage_buffer),
+              "Reverse-mode autodiff with adstacks needs Vulkan 1.3 (or Metal Argument Buffers Tier 2); this "
+              "device does not advertise `spirv_has_physical_storage_buffer`.");
+  QD_ERROR_IF(!device_->get_caps().get(DeviceCapability::spirv_has_int64),
+              "Reverse-mode autodiff with adstacks needs Vulkan 1.3 (or Metal Argument Buffers Tier 2); this "
+              "device does not advertise `spirv_has_int64`.");
+
   // Register each adstack-bearing task with the Program-side identity registry so the host raise site
   // can name the offending kernel + task in its diagnostic message. Idempotent: re-registration of the
   // same `&task_attribs[ti].ad_stack` returns the same id and just refreshes the metadata. The
@@ -259,7 +285,7 @@ std::vector GfxRuntime::publish_adstack_metadata_spirv(
   // memory) still need the on-device sizer below.
   if (all_size_exprs_host_resolvable(adstack_task_indices, task_attribs)) {
     eval_per_task_metadata_on_host(adstack_task_indices, task_attribs, program_impl_->program, host_ctx,
-                                   per_task_ad_stack);
+                                   per_task_ad_stack, max_reducer_results);
     return per_task_ad_stack;
   }
@@ -365,11 +391,11 @@
   SizeExprLaunchScope launch_scope;
   for (size_t k = 0; k < adstack_task_indices.size(); ++k) {
     size_t ti = adstack_task_indices[k];
-    per_task_bytecodes[k] = encode_adstack_size_expr_device_bytecode_for_spirv(task_attribs[ti].ad_stack,
-                                                                               program_impl_->program, &host_ctx);
+    per_task_bytecodes[k] = encode_adstack_size_expr_device_bytecode_for_spirv(
+        task_attribs[ti].ad_stack, program_impl_->program, &host_ctx, max_reducer_results);
     per_task_bytecode_offsets[k] = align_up(total_bytecode_bytes, kDescriptorOffsetAlignment);
     total_bytecode_bytes = per_task_bytecode_offsets[k] + per_task_bytecodes[k].size();
-    per_task_metadata_bytes[k] = (2u + 2u * task_attribs[ti].ad_stack.allocas.size()) * sizeof(uint32_t);
+    per_task_metadata_bytes[k] = (3u + 2u * task_attribs[ti].ad_stack.allocas.size()) * sizeof(uint32_t);
   }

   // Grow the shared bytecode scratch buffer if the concatenated blob outgrew it. Amortised doubling so
@@ -522,6 +548,19 @@
                 "bytecode_for_spirv` (wrong `kNodeOffArgBufferOffset` or missing `ExternalTensorRead` "
                 "pre-substitution) or in the sizer shader's PSB read path, not a legitimate workload.",
                 rt.stride_float, rt.stride_int, kMaxSaneStridePerThread);
+    // Cap-hit tripwire. The on-device sizer writes 1 into the trailing overflow-flag slot when it observes a
+    // `MaxOverRange` whose iteration count exceeds the `1<<24` cap; the hard error here surfaces the failure at
+    // `qd.sync()` with a clean attribution. Recognized `MaxOverRange` shapes are dispatched in parallel by the
+    // max-reducer and substituted to `Const` before the sizer interpreter sees them, so this path is reachable only
+    // for out-of-grammar shapes; broadening the recognizer grammar moves more shapes onto the loud path automatically.
+    const size_t overflow_word_idx = 2u + 2u * task_attribs[ti].ad_stack.allocas.size();
+    QD_ERROR_IF(overflow_word_idx < rt.metadata.size() && rt.metadata[overflow_word_idx] != 0,
+                "Adstack on-device sizer hit a `MaxOverRange` whose iteration count exceeds the {} cap. The recognized "
+                "grammar's max-reducer dispatch did not capture this shape so the substitution path could not pre-fold "
+                "the `MaxOverRange` to a `Const`.
 Restructure the source kernel to fit the recognizer grammar (single "
+                "bound variable per body, body limited to `Const` / `ExternalTensorRead(arg, [BoundVariable])` / `Add` "
+                "/ `Sub` / `Mul` / `Max`), or shrink the enclosing reverse-mode loop's iteration count below the cap.",
+                int64_t{1} << 24);

     ctx_buffers_.push_back(std::move(per_task_metadata_allocs[k]));
   }
diff --git a/quadrants/runtime/gfx/runtime.cpp b/quadrants/runtime/gfx/runtime.cpp
index 6d9f09b303..1e92ba8207 100644
--- a/quadrants/runtime/gfx/runtime.cpp
+++ b/quadrants/runtime/gfx/runtime.cpp
@@ -540,13 +540,18 @@ void GfxRuntime::launch_kernel(KernelHandle handle, LaunchContextBuilder &host_c
                              ti_kernel->ti_kernel_attribs().ctx_attribs.arr_access);
   }

+  // Max-reducer dispatch. Must precede `publish_adstack_metadata_spirv` so the per-spec substitution lands before
+  // the sizer's tree walk. Implementation lives in `runtime/gfx/adstack_max_reducer_launch.cpp`; that file
+  // early-returns an empty map on kernels with no captured specs, so the call below is cheap in the common case.
+  const auto max_reducer_results = dispatch_max_reducers(host_ctx, args_buffer.get(), any_arrays, task_attribs);
+
   // Device-side adstack SizeExpr evaluation: every task with adstack allocas has its per-alloca `max_size` /
   // `offset` metadata resolved by a dedicated compute shader (see `quadrants/runtime/gfx/adstack_sizer_launch.cpp`
   // for the full mechanism). The helper internally early-returns (after seeding the per-task vector with
   // compile-time strides) when no task has adstack allocas, so forward-only kernels pay only the cheap pre-populate
   // pass; the actual sizer dispatch + `wait_idle()` only fires for reverse-mode kernels.
   std::vector per_task_ad_stack = publish_adstack_metadata_spirv(
-      host_ctx, args_buffer.get(), any_arrays, task_attribs, ti_kernel->ti_kernel_attribs().name);
+      host_ctx, args_buffer.get(), any_arrays, task_attribs, ti_kernel->ti_kernel_attribs().name, max_reducer_results);

   // Static-IR-bound sparse-adstack-heap reducer dispatch. Gated on whether any task in this kernel has a captured
   // `bound_expr` - the codegen routes such tasks through the lazy LCA-block atomic-rmw row claim that reads
diff --git a/quadrants/runtime/gfx/runtime.h b/quadrants/runtime/gfx/runtime.h
index dd804fa842..9dd1557174 100644
--- a/quadrants/runtime/gfx/runtime.h
+++ b/quadrants/runtime/gfx/runtime.h
@@ -11,6 +11,7 @@
 #include "quadrants/struct/snode_tree.h"
 #include "quadrants/program/snode_expr_utils.h"
 #include "quadrants/program/program_impl.h"
+#include "quadrants/program/adstack_size_expr_eval.h"
 #include "quadrants/program/kernel_launcher.h"

 namespace quadrants::lang {
@@ -160,7 +161,8 @@ class QD_DLL_EXPORT GfxRuntime {
       DeviceAllocationGuard *args_buffer,
       const std::unordered_map<int, DeviceAllocation> &ndarray_allocs,
       const std::vector<spirv::TaskAttributes> &task_attribs,
-      const std::string &kernel_name);
+      const std::string &kernel_name,
+      const quadrants::lang::MaxReducerResultMap &max_reducer_results = quadrants::lang::MaxReducerResultMap{});

   // Static-IR-bound sparse-adstack-heap reducer dispatch. For each task with a captured ndarray-backed `bound_expr`,
   // dispatches the generic reducer compute shader (see `quadrants/codegen/spirv/adstack_bound_reducer_shader.{h,cpp}`)
@@ -175,6 +177,20 @@
       DeviceAllocationGuard *args_buffer,
       const std::vector<spirv::TaskAttributes> &task_attribs);

+  // Max-reducer dispatch.
+  // For each captured `StaticAdStackMaxReducerSpec` across every task in `task_attribs`, hits
+  // `AdStackCache::try_max_reducer_cache_hit` first; on miss dispatches `adstack_max_reducer_pipeline_` over `[0,
+  // length)` and atomic-SMaxes the body's per-thread result into the shared output buffer. The returned map is keyed
+  // by `(registry_id, stack_id, mor_node_idx)` packed via the same `AdStackCache` encoding so
+  // `substitute_precomputed_max_over_range` can substitute results into per-stack `SerializedSizeExpr` trees before
+  // the per-thread sizer or device sizer encoder walks them. Empty map on capability-missing devices or kernels with
+  // no captured specs (caller falls through to the existing capped path). Implementation lives in
+  // `runtime/gfx/adstack_max_reducer_launch.cpp`.
+  quadrants::lang::MaxReducerResultMap dispatch_max_reducers(
+      LaunchContextBuilder &host_ctx,
+      DeviceAllocationGuard *args_buffer,
+      const std::unordered_map<int, DeviceAllocation> &ndarray_allocs,
+      const std::vector<spirv::TaskAttributes> &task_attribs);
+
   void init_nonroot_buffers();

   Device *device_{nullptr};
@@ -270,6 +286,23 @@
   // Metal / MoltenVK by the same RHI rule the slot-3 placeholder above guards against.
   std::unique_ptr<DeviceAllocationGuard> adstack_bound_reducer_args_placeholder_buffer_;

+  // Max-reducer per-`GfxRuntime` plumbing. Built once on the first launch that contains a task with non-empty
+  // `max_reducer_specs`, reused across every such launch afterwards. Null on backends without
+  // `spirv_has_physical_storage_buffer` + `spirv_has_int64`; on such devices `dispatch_max_reducers` returns an
+  // empty map and `publish_adstack_metadata_spirv`'s cap gate raises for reverse-mode kernels, so the device sizer's
+  // `1<<24` cap can no longer be hit silently. The grow-on-demand buffers below hold per-spec params blobs
+  // (binding 2), the body bytecode payload (binding 3), and the per-spec output i64 slots (binding 1). Slot 0 is the
+  // kernel arg buffer.
+  std::unique_ptr<Pipeline> adstack_max_reducer_pipeline_{nullptr};
+  std::unique_ptr<DeviceAllocationGuard> adstack_max_reducer_params_buffer_;
+  size_t adstack_max_reducer_params_buffer_size_{0};
+  std::unique_ptr<DeviceAllocationGuard> adstack_max_reducer_bytecode_buffer_;
+  size_t adstack_max_reducer_bytecode_buffer_size_{0};
+  std::unique_ptr<DeviceAllocationGuard> adstack_max_reducer_output_buffer_;
+  size_t adstack_max_reducer_output_buffer_size_{0};
+  // Slot-0 placeholder buffer for kernels with no kernel arg buffer (SNode-only kernels with `args_buffer == null`).
+  // Same RHI rule as the bound-reducer's slot-0 placeholder: descriptor-set layouts require a non-null binding.
+  std::unique_ptr<DeviceAllocationGuard> adstack_max_reducer_args_placeholder_buffer_;
+
   // Per-kernel `BufferType::AdStackBoundRowCapacity` (`uint[num_tasks_in_kernel]`). Populated by the host after the
   // bound-reducer dispatch with each task's exact reducer count (UINT32_MAX for tasks without a captured
   // `bound_expr`, so the codegen-emitted defense-in-depth bounds check is inert on those).
   // Bound to the main task on
diff --git a/quadrants/runtime/llvm/CMakeLists.txt b/quadrants/runtime/llvm/CMakeLists.txt
index 31d341e3b7..6d448d24d8 100644
--- a/quadrants/runtime/llvm/CMakeLists.txt
+++ b/quadrants/runtime/llvm/CMakeLists.txt
@@ -4,7 +4,9 @@ add_library(llvm_runtime)
 target_sources(llvm_runtime
   PRIVATE
     llvm_runtime_executor.cpp
-    llvm_adstack_lazy_claim.cpp
+    adstack_lazy_claim/bound_eval.cpp
+    adstack_lazy_claim/metadata_publish.cpp
+    adstack_lazy_claim/heap_grow.cpp
     llvm_context.cpp
     snode_tree_buffer_manager.cpp
     kernel_launcher.cpp
diff --git a/quadrants/runtime/llvm/adstack_lazy_claim/bound_eval.cpp b/quadrants/runtime/llvm/adstack_lazy_claim/bound_eval.cpp
new file mode 100644
index 0000000000..e71f48a004
--- /dev/null
+++ b/quadrants/runtime/llvm/adstack_lazy_claim/bound_eval.cpp
@@ -0,0 +1,765 @@
+// Stage A of the LLVM sparse-adstack-heap lazy-claim pipeline: per-launch buffer publish + bound-expression
+// evaluation. See `adstack_lazy_claim/bound_eval.h` for the stage-level documentation.
+
+#include "quadrants/runtime/llvm/adstack_lazy_claim/bound_eval.h"
+
+#include "quadrants/runtime/llvm/llvm_runtime_executor.h"
+#include "quadrants/program/adstack_size_expr_eval.h"
+#include "quadrants/program/program.h"
+
+#include <algorithm>
+#include <cstring>
+#include <limits>
+#include <vector>
+
+#include "quadrants/ir/adstack_size_expr_device.h"
+#include "quadrants/ir/static_adstack_bound_reducer_device.h"
+#include "quadrants/ir/static_adstack_max_reducer_device.h"
+#include "quadrants/ir/type_factory.h"
+#include "quadrants/program/launch_context_builder.h"
+#include "quadrants/program/program_impl.h"
+#include "quadrants/rhi/llvm/llvm_device.h"
+
+#include "quadrants/platform/cuda/detect_cuda.h"
+#include "quadrants/rhi/cuda/cuda_driver.h"
+
+#include "quadrants/platform/amdgpu/detect_amdgpu.h"
+#include "quadrants/rhi/amdgpu/amdgpu_driver.h"
+
+namespace quadrants::lang {
+
+uint32_t LlvmRuntimeExecutor::publish_per_task_bound_count_cpu(std::size_t task_index,
+                                                               const AdStackSizingInfo &ad_stack,
+                                                               std::size_t length,
+                                                               LaunchContextBuilder *ctx) {
+  // Default to UINT32_MAX (no clamp); only override on a successful host evaluation. The codegen-emitted bounds clamp
+  // at the float LCA-block claim site stays inert when the slot holds UINT32_MAX, so this fall-through is a no-op
+  // that preserves the existing behaviour.
+  if (config_.arch != Arch::x64 && config_.arch != Arch::arm64) {
+    return std::numeric_limits<uint32_t>::max();
+  }
+  if (!ad_stack.bound_expr.has_value()) {
+    return std::numeric_limits<uint32_t>::max();
+  }
+  const auto &be = ad_stack.bound_expr.value();
+
+  // Resolve the per-iteration field address. Two source kinds (mirrors the device-side reducer in
+  // `runtime_eval_static_bound_count`):
+  //  * NdArray: walk `arg_buffer + data_ptr_byte_off` to fetch the ndarray's data pointer; the gating field is then
+  //    `data_ptr[i]` for `i in [0, length)`. On CPU `arg_buffer` lives in host memory, so the deref is direct.
+  //  * SNode: walk `runtime->roots[snode_root_id] + snode_byte_base_offset + i * snode_byte_cell_stride`
+  //    for `i in [0, length)`. The byte offset / cell stride were resolved by the codegen-time SNode descriptor
+  //    resolver (via `compile_snode_structs`); `runtime->roots` is host-resident on CPU and reachable through the
+  //    `LLVMRuntime_get_roots` STRUCT_FIELD_ARRAY getter.
+  // Without the SNode arm, kernels with a captured SNode-backed bound_expr leave the capacity slot at UINT32_MAX
+  // (the `publish_adstack_lazy_claim_buffers` default), `ensure_per_task_float_heap_post_reducer` sizes the float
+  // heap at the worst-case num_threads count, and the codegen-emitted clamp goes inert - exactly the regression a
+  // `for i in selector: if selector[i] > eps:` SNode-gated reverse kernel hits when the float adstack heap can only
+  // hold `num_cpu_threads` rows but the LCA-block atomic-rmw fires once per gated iteration.
+  using FSK = StaticAdStackBoundExpr::FieldSourceKind;
+  if (be.field_source_kind != FSK::NdArray && be.field_source_kind != FSK::SNode) {
+    return std::numeric_limits<uint32_t>::max();
+  }
+
+  const char *field_base = nullptr;
+  std::size_t field_stride_bytes = 0;
+  if (be.field_source_kind == FSK::NdArray) {
+    if (ctx == nullptr || ctx->args_type == nullptr || ctx->get_context().arg_buffer == nullptr) {
+      return std::numeric_limits<uint32_t>::max();
+    }
+    std::vector<int> indices = be.ndarray_arg_id;
+    indices.push_back(TypeFactory::DATA_PTR_POS_IN_NDARRAY);
+    std::size_t data_ptr_byte_off = ctx->args_type->get_element_offset(indices);
+    const char *arg_buffer = static_cast<const char *>(ctx->get_context().arg_buffer);
+    void *data_ptr = *reinterpret_cast<void *const *>(arg_buffer + data_ptr_byte_off);
+    if (data_ptr == nullptr) {
+      return std::numeric_limits<uint32_t>::max();
+    }
+    field_base = static_cast<const char *>(data_ptr);
+    field_stride_bytes = be.field_dtype_is_double ? sizeof(double) : sizeof(int32_t);  // f32 / i32 = 4 B, f64 = 8 B.
+  } else {
+    // SNode-backed source: query the host-resident `runtime->roots[snode_root_id]` pointer through the
+    // STRUCT_FIELD_ARRAY getter; on CPU this is an in-process call (no DtoH stage) and returns the dense root buffer
+    // base address directly.
+    if (be.snode_root_id < 0 || llvm_runtime_ == nullptr || result_buffer_cache_ == nullptr) {
+      return std::numeric_limits<uint32_t>::max();
+    }
+    // `RUNTIME_STRUCT_FIELD_ARRAY(LLVMRuntime, roots)` defines `runtime_LLVMRuntime_get_roots(LLVMRuntime *runtime,
+    // LLVMRuntime *s, int i)` (the macro takes a struct-of-interest argument distinct from the runtime context, but
+    // for fields of `LLVMRuntime` itself the two pointers are the same). `runtime_query` auto-prepends
+    // `llvm_runtime_` as the first arg, so we pass `(llvm_runtime_, root_id)` to make the call resolve to the 3-arg
+    // signature `(llvm_runtime_, llvm_runtime_, root_id)`. Mirrors the `node_allocators` call site a few hundred
+    // lines above.
+    void *root_ptr =
+        runtime_query<void *>("LLVMRuntime_get_roots", result_buffer_cache_, llvm_runtime_, be.snode_root_id);
+    if (root_ptr == nullptr) {
+      return std::numeric_limits<uint32_t>::max();
+    }
+    field_base = static_cast<const char *>(root_ptr) + be.snode_byte_base_offset;
+    field_stride_bytes = static_cast<std::size_t>(be.snode_byte_cell_stride);
+  }
+
+  // Walk `[0, length)` evaluating the captured predicate on each thread's `field[i]`. The polarity bit selects
+  // enter-on-true vs enter-on-false at the LCA's IfStmt; the count we publish is always the number of threads that
+  // REACH the LCA, regardless of the gate orientation. f64 gates dispatch through the same float-source arm but read
+  // the source as `double*` and compare against `literal_f64` so the f64 precision the user declared is preserved
+  // end-to-end (narrowing the literal to f32 here would risk false-positive / negative counts on gates whose
+  // threshold sits within the f32 representable gap). The comparison routes through `eval_cmp`; a sketch of its
+  // likely shape follows.
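The gate loops below call `eval_cmp`, which is defined elsewhere in the executor. A hedged sketch of its likely shape, with a hypothetical `CmpOp` enum (the project's actual enumerator names may differ):

```cpp
// Illustrative comparison helper: maps a captured comparison op onto the
// corresponding C++ relational operator for any ordered scalar type.
enum class CmpOp { lt, le, gt, ge, eq, ne };  // hypothetical names

template <typename T>
bool eval_cmp_sketch(CmpOp op, T lhs, T rhs) {
  switch (op) {
    case CmpOp::lt: return lhs < rhs;
    case CmpOp::le: return lhs <= rhs;
    case CmpOp::gt: return lhs > rhs;
    case CmpOp::ge: return lhs >= rhs;
    case CmpOp::eq: return lhs == rhs;
    case CmpOp::ne: return lhs != rhs;
  }
  return false;  // unreachable for valid ops
}
```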
+ uint32_t count = 0; + if (be.field_dtype_is_float) { + if (be.field_dtype_is_double) { + for (std::size_t i = 0; i < length; ++i) { + const double v = *reinterpret_cast<const double *>(field_base + i * field_stride_bytes); + const bool match = eval_cmp(be.cmp_op, v, be.literal_f64); + if (be.polarity ? match : !match) { + ++count; + } + } + } else { + for (std::size_t i = 0; i < length; ++i) { + const float v = *reinterpret_cast<const float *>(field_base + i * field_stride_bytes); + const bool match = eval_cmp(be.cmp_op, v, be.literal_f32); + if (be.polarity ? match : !match) { + ++count; + } + } + } + } else { + for (std::size_t i = 0; i < length; ++i) { + const int32_t v = *reinterpret_cast<const int32_t *>(field_base + i * field_stride_bytes); + const bool match = eval_cmp(be.cmp_op, v, be.literal_i32); + if (be.polarity ? match : !match) { + ++count; + } + } + } + + // Publish the count into `runtime->adstack_bound_row_capacities[task_index]` so the codegen-emitted bounds clamp at + // the float LCA-block claim site reads it back as the per-task capacity. Slot was reset to UINT32_MAX by + // `publish_adstack_lazy_claim_buffers`; this overwrite tightens it to the real count. + if (runtime_adstack_bound_row_capacities_field_ptr_ == nullptr || adstack_bound_row_capacities_alloc_ == nullptr) { + return count; + } + void *bound_capacities_dev_ptr = get_device_alloc_info_ptr(*adstack_bound_row_capacities_alloc_); + // CPU only: write directly into the host-resident array. + uint32_t *slots = static_cast<uint32_t *>(bound_capacities_dev_ptr); + slots[task_index] = count; + return count; +} + +void LlvmRuntimeExecutor::publish_per_task_bound_count_device(std::size_t task_index, + const AdStackSizingInfo &ad_stack, + std::size_t length, + LaunchContextBuilder *ctx, + void *device_runtime_context_ptr) { + // Only fires for CUDA / AMDGPU; CPU goes through `publish_per_task_bound_count_cpu`. Bail when the task did not + // capture a bound_expr (no clamp needed - the slot stays at the UINT32_MAX default that + // `publish_adstack_lazy_claim_buffers` wrote). Both ndarray and SNode source kinds are dispatched through the same + // params blob; the device-side reducer selects between them via `field_source_is_snode`. + if (config_.arch != Arch::cuda && config_.arch != Arch::amdgpu) { + return; + } + if (!ad_stack.bound_expr.has_value()) { + return; + } + const auto &be = ad_stack.bound_expr.value(); + const bool is_snode_source = be.field_source_kind == StaticAdStackBoundExpr::FieldSourceKind::SNode; + if (ctx == nullptr || ctx->args_type == nullptr) { + return; + } + const uint32_t cmp_op_encoded = encode_cmp_op_for_llvm_reducer(be.cmp_op); + if (cmp_op_encoded == std::numeric_limits<uint32_t>::max()) { + return; // unrecognised comparison op (the IR pattern matcher should have rejected it earlier) + } + + // Fill the device-side params struct on the host. Threshold bits live as the same u32 the runtime function bitcasts + // back; we copy whichever underlying integer or float value the analysis captured. The two source shapes (ndarray + + // SNode) share the comparison fields and differ only in which trailing fields the reducer reads (`arg_word_offset` + // for ndarray, `snode_root_id` + `snode_byte_*` for SNode); host-side we populate the matching pair and zero out the + // other. + LlvmAdStackBoundReducerDeviceParams params{}; + params.task_index = static_cast<uint32_t>(task_index); + params.length = static_cast<uint32_t>(is_snode_source ? be.snode_iter_count : length); + params.cmp_op = cmp_op_encoded; + params.field_dtype_is_float = be.field_dtype_is_float ? 1u : 0u; + params.field_dtype_is_double = be.field_dtype_is_double ? 1u : 0u; + params.polarity = be.polarity ? 1u : 0u; + if (be.field_dtype_is_double) { + // Pack the f64 threshold's 64-bit pattern into the (lo, hi) u32 pair the reducer reassembles. + uint64_t bits64 = 0; + std::memcpy(&bits64, &be.literal_f64, sizeof(uint64_t)); + params.threshold_bits = static_cast<uint32_t>(bits64 & 0xFFFFFFFFu); + params.threshold_bits_high = static_cast<uint32_t>(bits64 >> 32); + } else if (be.field_dtype_is_float) { + std::memcpy(&params.threshold_bits, &be.literal_f32, sizeof(uint32_t)); + } else { + params.threshold_bits = static_cast<uint32_t>(be.literal_i32); + } + params.field_source_is_snode = is_snode_source ? 1u : 0u; + if (is_snode_source) { + params.arg_word_offset = 0; + params.snode_root_id = static_cast<uint32_t>(be.snode_root_id); + params.snode_byte_base_offset = be.snode_byte_base_offset; + params.snode_byte_cell_stride = be.snode_byte_cell_stride; + } else { + // Resolve the ndarray data pointer's word offset within the kernel arg buffer. Same path the SPIR-V reducer and the + // CPU host-eval use; bytes -> words for the reducer's `arg_buffer_u32[arg_word_offset]` indexing. + std::vector<int> indices = be.ndarray_arg_id; + indices.push_back(TypeFactory::DATA_PTR_POS_IN_NDARRAY); + std::size_t data_ptr_byte_off = ctx->args_type->get_element_offset(indices); + if (data_ptr_byte_off % sizeof(uint32_t) != 0) { + return; // misaligned offset; the reducer's u32-word indexing would lose bits. + } + params.arg_word_offset = static_cast<uint32_t>(data_ptr_byte_off / sizeof(uint32_t)); + params.snode_root_id = 0; + params.snode_byte_base_offset = 0; + params.snode_byte_cell_stride = 0; + } + + // Lazy-allocate the device-side params scratch buffer the first time a bound_expr task fires; reuse for subsequent + // tasks across kernels. Sized for one struct (the reducer is single-task per call); a future optimisation could pack + // multiple tasks' params into one buffer and dispatch them in a single launch. + const std::size_t needed_bytes = sizeof(LlvmAdStackBoundReducerDeviceParams); + if (needed_bytes > adstack_bound_reducer_params_capacity_) { + Device::AllocParams alloc_params{}; + alloc_params.size = std::max(needed_bytes, 2 * adstack_bound_reducer_params_capacity_); + alloc_params.host_read = false; + alloc_params.host_write = true; + alloc_params.export_sharing = false; + alloc_params.usage = AllocUsage::Storage; + DeviceAllocation new_alloc; + RhiResult res = llvm_device()->allocate_memory(alloc_params, &new_alloc); + QD_ERROR_IF(res != RhiResult::success, + "Failed to allocate {} bytes for adstack bound reducer params buffer (err: {})", alloc_params.size, + int(res)); + adstack_bound_reducer_params_alloc_ = std::make_unique<DeviceAllocation>(std::move(new_alloc)); + adstack_bound_reducer_params_capacity_ = alloc_params.size; + } + void *params_dev_ptr = get_device_alloc_info_ptr(*adstack_bound_reducer_params_alloc_); + + // h2d the params struct into the device buffer. + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_host_to_device(params_dev_ptr, &params, needed_bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_host_to_device(params_dev_ptr, &params, needed_bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } + + // Dispatch the runtime reducer function: single-threaded device-side walk that reads `ctx->arg_buffer` (the + // device-mirror the launcher staged) and writes the count into `runtime->adstack_bound_row_capacities[task_index]`. + // Pass the device-side `RuntimeContext` pointer the same way the size-expr sizer does so the function can deref + // `ctx->arg_buffer` on-device. + auto *const runtime_jit = get_runtime_jit_module(); + void *runtime_context_ptr_for_reducer = + device_runtime_context_ptr != nullptr ? device_runtime_context_ptr : static_cast<void *>(&ctx->get_context()); + runtime_jit->call("runtime_eval_static_bound_count", llvm_runtime_, + runtime_context_ptr_for_reducer, params_dev_ptr); +} + +std::unordered_map<uint64_t, int64_t> LlvmRuntimeExecutor::dispatch_max_reducers_for_tasks( + const std::vector<OffloadedTask> &tasks, + LaunchContextBuilder *ctx, + void *device_runtime_context_ptr) { + std::vector<AdStackSizingInfo> ad_stacks_view; + ad_stacks_view.reserve(tasks.size()); + for (const auto &t : tasks) { + ad_stacks_view.push_back(t.ad_stack); + } + return dispatch_max_reducers_for_tasks(ad_stacks_view, ctx, device_runtime_context_ptr); +} + +std::unordered_map<uint64_t, int64_t> LlvmRuntimeExecutor::dispatch_max_reducers_for_tasks( + const std::vector<AdStackSizingInfo> &ad_stacks, + LaunchContextBuilder *ctx, + void *device_runtime_context_ptr) { + using quadrants::lang::AdStackSizeExprDeviceNode; + using quadrants::lang::EncodedMaxReducerBody; + using quadrants::lang::LlvmAdStackMaxReducerDeviceParams; + using quadrants::lang::MaxReducerResultMap; + + // Reset the per-launch transient before every dispatch so a kernel with no captured specs sees an empty map. + current_max_reducer_results_.clear(); + MaxReducerResultMap result; + if (ctx == nullptr || ctx->args_type == nullptr) { + fprintf(stderr, "[trace dispatch_max_reducers] ENTER ctx=%p args_type=%p -> early return (no ctx)\n", (void *)ctx, + (void *)(ctx ? ctx->args_type : nullptr)); + fflush(stderr); + return result; + } + Program *prog = (program_impl_ != nullptr) ? program_impl_->program : nullptr; + AdStackCache *cache = (prog != nullptr) ? &prog->adstack_cache() : nullptr; + fprintf(stderr, + "[trace dispatch_max_reducers] ENTER arch=%d ad_stacks.size()=%zu device_runtime_ctx=%p host_ctx=%p " + "llvm_runtime=%p\n", + (int)config_.arch, ad_stacks.size(), device_runtime_context_ptr, static_cast<void *>(&ctx->get_context()), + llvm_runtime_); + fflush(stderr); + + // Pass 1: per-spec cache lookup. Hits drop straight into `result`; misses go to pending with back-refs to the + // source `SerializedSizeExpr` and `StaticAdStackMaxReducerSpec`. Host-evaluation of begin / end and body bytecode + // encoding is deferred to the per-level prepare step below, where each spec's `dependent_mor_node_idxs` have + // already been substituted into the working tree. Mirrors the gfx variant's level-based dispatch so a captured + // outer `MaxOverRange` whose end references a captured inner `MaxOverRange` resolves through the inner's parallel + // dispatch instead of host-walking it (which would either trip the host evaluator's `1<<24` cap or read garbage + // through device-resident ndarray buffers on launchers that do not host-accessibilise their data pointers). + struct PendingMaxReducerDispatch { + uint64_t cache_key; + uint32_t registry_id; + int32_t stack_id; + int32_t mor_node_idx; + const SerializedSizeExpr *expr; + const StaticAdStackMaxReducerSpec *spec; + bool dispatched{false}; + bool dropped{false}; + LlvmAdStackMaxReducerDeviceParams params; + std::vector<uint8_t> body_bytecode; + std::vector reads; + }; + std::vector<PendingMaxReducerDispatch> pending; + for (std::size_t ti = 0; ti < ad_stacks.size(); ++ti) { + const auto &ad_stack = ad_stacks[ti]; + if (ad_stack.max_reducer_specs.empty()) { + continue; + } + const uint32_t registry_id = ad_stack.registry_id; + if (registry_id == 0) { + continue; + } + for (const auto &spec : ad_stack.max_reducer_specs) { + const uint64_t key = (static_cast<uint64_t>(registry_id) & 0xFFFFFFFFull) | + ((static_cast<uint64_t>(spec.stack_id) & 0xFFFFull) << 32) | + ((static_cast<uint64_t>(spec.mor_node_idx) & 0xFFFFull) << 48); + if (cache != nullptr) { + int64_t cached; + if (cache->try_max_reducer_cache_hit(registry_id, spec.stack_id, spec.mor_node_idx, ctx, cached)) { + result[key] = cached; + continue; + } + } + PendingMaxReducerDispatch p{}; + p.cache_key = key; + p.registry_id = registry_id; + p.stack_id = spec.stack_id; + p.mor_node_idx = spec.mor_node_idx; + p.expr = &ad_stack.size_exprs[spec.stack_id]; + p.spec = &spec; + fprintf(stderr, + "[trace dispatch_max_reducers] miss reg=%u stack=%d mor=%d cache_key=0x%016llx num_axes=%zu deps=%zu\n", + registry_id, spec.stack_id, spec.mor_node_idx, (unsigned long long)key, spec.axis_var_ids.size(), + spec.dependent_mor_node_idxs.size()); + fflush(stderr); + pending.push_back(std::move(p)); + } + } + fprintf(stderr, "[trace dispatch_max_reducers] pass1 done, pending=%zu cache_hits=%zu\n", pending.size(), + result.size()); + fflush(stderr); + if (pending.empty()) { + return result; + } + + // Lazy-resolve the runtime-field address for `adstack_max_reducer_outputs` once per program lifetime. + if (runtime_adstack_max_reducer_outputs_field_ptr_ == nullptr) { + fprintf(stderr, "[trace dispatch_max_reducers] resolving runtime_adstack_max_reducer_outputs_field_ptr_\n"); + fflush(stderr); + auto *const runtime_jit = get_runtime_jit_module(); + runtime_jit->call("runtime_get_adstack_max_reducer_field_ptr", llvm_runtime_); + runtime_adstack_max_reducer_outputs_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_)); + fprintf(stderr, "[trace dispatch_max_reducers] field_ptr=%p\n", runtime_adstack_max_reducer_outputs_field_ptr_); + fflush(stderr); + } + + auto copy_h2d = [&](void *dst, const void *src, std::size_t bytes) { + fprintf(stderr, "[trace copy_h2d] dst=%p src=%p bytes=%zu arch=%d\n", dst, src, bytes, (int)config_.arch); + fflush(stderr); + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_host_to_device(dst, const_cast<void *>(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_host_to_device(dst, const_cast<void *>(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + std::memcpy(dst, src, bytes); + } + fprintf(stderr, "[trace copy_h2d] done\n"); + fflush(stderr); + }; + auto copy_d2h = [&](void *dst, const void *src, std::size_t bytes) { + fprintf(stderr, "[trace copy_d2h] dst=%p src=%p bytes=%zu arch=%d\n", dst, src, bytes, (int)config_.arch); + fflush(stderr); + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_device_to_host(dst, const_cast<void *>(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_device_to_host(dst, const_cast<void *>(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + std::memcpy(dst, src, bytes); + } + fprintf(stderr, "[trace copy_d2h] done\n"); + fflush(stderr); + }; + + auto *const runtime_jit = get_runtime_jit_module(); + void *runtime_context_ptr_for_reducer = + device_runtime_context_ptr != nullptr ? device_runtime_context_ptr : static_cast<void *>(&ctx->get_context()); + + auto arg_buffer_offset_resolver = [&](const std::vector<int> &arg_id_path) -> int32_t { + std::vector<int> path(arg_id_path.begin(), arg_id_path.end()); + path.push_back(TypeFactory::DATA_PTR_POS_IN_NDARRAY); + const std::size_t byte_off = ctx->args_type->get_element_offset(path); + if (byte_off > static_cast<std::size_t>(std::numeric_limits<int32_t>::max())) { + return -1; + } + return static_cast<int32_t>(byte_off); + }; + + // Level-based dispatch: each iteration processes the specs whose `dependent_mor_node_idxs` are all in `result` + // (cache hits + earlier rounds), substitutes those values into the working tree, host-evaluates begin / end, + // encodes body bytecode, then dispatches the round one spec at a time (the LLVM runtime function is single- + // threaded; batching is per-round only at the spec-prep level). Most kernels finish in one round; nested patterns + // (e.g. outer MaxOverRange whose end contains a captured inner max-of-array) take one round per dependency depth. + // No-progress rounds drop the remaining specs and let the per-task sizer's loud-error path absorb them.
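+ // In sketch form, the scheme reduces to rounds over a dependency set (an illustrative aside with a
+ // hypothetical minimal `Spec{deps, id}` shape - not the shipped types):
+ //
+ //   std::vector<int> dispatch_in_rounds(const std::vector<Spec> &specs) {
+ //     std::vector<int> order;
+ //     std::set<int> resolved;
+ //     std::vector<bool> done(specs.size(), false);
+ //     for (;;) {
+ //       bool progress = false;
+ //       for (std::size_t k = 0; k < specs.size(); ++k) {
+ //         if (done[k]) continue;
+ //         bool deps_ok = true;
+ //         for (int d : specs[k].deps)
+ //           if (!resolved.count(d)) { deps_ok = false; break; }
+ //         if (!deps_ok) continue;
+ //         order.push_back(specs[k].id);   // ready this round: all deps already resolved
+ //         resolved.insert(specs[k].id);
+ //         done[k] = true;
+ //         progress = true;
+ //       }
+ //       if (!progress) break;  // survivors are dropped, mirroring the no-progress arm below
+ //     }
+ //     return order;
+ //   }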
+ std::size_t dispatched_count = 0; + std::size_t dropped_count = 0; + std::size_t round_idx = 0; + while (dispatched_count + dropped_count < pending.size()) { + fprintf(stderr, "[trace dispatch_max_reducers] ROUND %zu start (dispatched=%zu dropped=%zu pending=%zu)\n", + round_idx, dispatched_count, dropped_count, pending.size()); + fflush(stderr); + std::vector<std::size_t> level_indices; + for (std::size_t k = 0; k < pending.size(); ++k) { + if (pending[k].dispatched || pending[k].dropped) + continue; + bool deps_ok = true; + for (int32_t dep_node : pending[k].spec->dependent_mor_node_idxs) { + const uint64_t dep_key = (static_cast<uint64_t>(pending[k].registry_id) & 0xFFFFFFFFull) | + ((static_cast<uint64_t>(pending[k].stack_id) & 0xFFFFull) << 32) | + ((static_cast<uint64_t>(dep_node) & 0xFFFFull) << 48); + if (result.find(dep_key) == result.end()) { + deps_ok = false; + break; + } + } + if (deps_ok) + level_indices.push_back(k); + } + fprintf(stderr, "[trace dispatch_max_reducers] ROUND %zu level_indices.size()=%zu\n", round_idx, + level_indices.size()); + fflush(stderr); + if (level_indices.empty()) { + fprintf(stderr, "[trace dispatch_max_reducers] ROUND %zu no progress; dropping remaining %zu specs\n", round_idx, + pending.size() - dispatched_count - dropped_count); + fflush(stderr); + for (std::size_t k = 0; k < pending.size(); ++k) { + if (!pending[k].dispatched && !pending[k].dropped) { + pending[k].dropped = true; + ++dropped_count; + } + } + break; + } + + // Prepare each ready spec: substitute already-resolved deps' values into the tree, host-eval begin / end, encode + // body bytecode. Specs whose preparation fails (axis non-resolvable, length over u32 cap, body grammar reject) + // mark `dropped` and are skipped for this round and forever. + std::vector<std::size_t> level_dispatch; + level_dispatch.reserve(level_indices.size()); + for (std::size_t k : level_indices) { + const auto *spec = pending[k].spec; + const std::size_t num_axes = spec->axis_var_ids.size(); + if (num_axes == 0 || num_axes > static_cast<std::size_t>(kAdStackMaxReducerMaxAxes)) { + pending[k].dropped = true; + ++dropped_count; + continue; + } + const SerializedSizeExpr substituted = + substitute_precomputed_max_over_range(*pending[k].expr, pending[k].registry_id, pending[k].stack_id, result); + std::vector<int64_t> per_axis_begin(num_axes, 0); + std::vector<int64_t> per_axis_length(num_axes, 0); + bool axes_ok = true; + uint64_t total_length = 1; + for (std::size_t a = 0; a < num_axes; ++a) { + const int64_t bv = evaluate_adstack_size_expr_at_node(substituted, spec->axis_begin_node_idxs[a], prog, ctx); + const int64_t ev = evaluate_adstack_size_expr_at_node(substituted, spec->axis_end_node_idxs[a], prog, ctx); + if (bv < 0 || ev < 0 || ev <= bv) { + axes_ok = false; + break; + } + per_axis_begin[a] = bv; + per_axis_length[a] = ev - bv; + total_length *= static_cast<uint64_t>(per_axis_length[a]); + if (total_length > std::numeric_limits<uint32_t>::max()) { + axes_ok = false; + break; + } + } + if (!axes_ok) { + pending[k].dropped = true; + ++dropped_count; + continue; + } + EncodedMaxReducerBody encoded = encode_max_reducer_body_bytecode( + substituted, spec->body_node_idx, spec->axis_var_ids, arg_buffer_offset_resolver, ctx, prog); + if (encoded.body_node_count == 0) { + pending[k].dropped = true; + ++dropped_count; + continue; + } + pending[k].params = LlvmAdStackMaxReducerDeviceParams{}; + // `output_slot` is assigned the round-local index after `level_dispatch` is finalised, just before the dispatch + // loop below; this matches the gfx launcher's per-round output-buffer reuse pattern. + pending[k].params.num_axes = static_cast<uint32_t>(num_axes); + pending[k].params.body_node_count = encoded.body_node_count; + pending[k].params.body_root_node_idx = static_cast<int32_t>(encoded.body_node_count) - 1; + for (std::size_t a = 0; a < num_axes; ++a) { + pending[k].params.per_axis_length[a] = static_cast<uint32_t>(per_axis_length[a]); + pending[k].params.per_axis_begin[a] = per_axis_begin[a]; + pending[k].params.per_axis_var_id[a] = static_cast<uint32_t>(a); + } + pending[k].body_bytecode = std::move(encoded.bytes); + pending[k].reads = std::move(encoded.body_reads); + fprintf(stderr, + "[trace dispatch_max_reducers] prep[k=%zu] reg=%u stack=%d mor=%d num_axes=%zu total_length=%llu " + "body_node_count=%u body_bytecode_size=%zu\n", + k, pending[k].registry_id, pending[k].stack_id, pending[k].mor_node_idx, num_axes, + (unsigned long long)total_length, encoded.body_node_count, pending[k].body_bytecode.size()); + fflush(stderr); + level_dispatch.push_back(k); + } + fprintf(stderr, "[trace dispatch_max_reducers] ROUND %zu level_dispatch.size()=%zu\n", round_idx, + level_dispatch.size()); + fflush(stderr); + if (level_dispatch.empty()) { + ++round_idx; + continue; + } + + // Lazy-grow + (re-)publish the outputs buffer for this round. Sized to `level_dispatch.size()` i64 slots so each + // dispatched spec's `output_slot` is its position within `level_dispatch` (round-local), matching the gfx + // launcher's per-round output-buffer reuse. Re-publishing is required if the alloc grew across rounds because the + // runtime field stores a raw device pointer. + const std::size_t needed_output_bytes = level_dispatch.size() * sizeof(int64_t); + fprintf(stderr, "[trace dispatch_max_reducers] ROUND %zu outputs alloc check: needed=%zu capacity=%zu\n", round_idx, + needed_output_bytes, adstack_max_reducer_outputs_capacity_); + fflush(stderr); + if (needed_output_bytes > adstack_max_reducer_outputs_capacity_) { + Device::AllocParams alloc_params{}; + alloc_params.size = std::max(needed_output_bytes, 2 * adstack_max_reducer_outputs_capacity_); + alloc_params.host_read = false; + alloc_params.host_write = true; + alloc_params.export_sharing = false; + alloc_params.usage = AllocUsage::Storage; + fprintf(stderr, "[trace dispatch_max_reducers] allocating outputs buffer size=%zu\n", + static_cast<std::size_t>(alloc_params.size)); + fflush(stderr); + DeviceAllocation new_alloc; + RhiResult res = llvm_device()->allocate_memory(alloc_params, &new_alloc); + fprintf(stderr, "[trace dispatch_max_reducers] allocate_memory returned res=%d\n", (int)res); + fflush(stderr); + QD_ERROR_IF(res != RhiResult::success, + "Failed to allocate {} bytes for adstack max reducer outputs buffer (err: {})", alloc_params.size, + int(res)); + adstack_max_reducer_outputs_alloc_ = std::make_unique<DeviceAllocation>(std::move(new_alloc)); + adstack_max_reducer_outputs_capacity_ = alloc_params.size; + } + void *outputs_dev_ptr = get_device_alloc_info_ptr(*adstack_max_reducer_outputs_alloc_); + fprintf(stderr, "[trace dispatch_max_reducers] ROUND %zu outputs_dev_ptr=%p; publishing to field_ptr=%p\n", + round_idx, outputs_dev_ptr, runtime_adstack_max_reducer_outputs_field_ptr_); + fflush(stderr); + copy_h2d(runtime_adstack_max_reducer_outputs_field_ptr_, &outputs_dev_ptr, sizeof(void *)); + + // Assign each ready spec's `output_slot` to its round-local index within `level_dispatch`, then h2d its params + + // body bytecode and invoke the single-threaded runtime function. + for (std::size_t i = 0; i < level_dispatch.size(); ++i) { + const std::size_t k = level_dispatch[i]; + pending[k].params.output_slot = static_cast<uint32_t>(i); + const std::size_t needed_params_bytes = sizeof(LlvmAdStackMaxReducerDeviceParams); + if (needed_params_bytes > adstack_max_reducer_params_capacity_) { + Device::AllocParams alloc_params{}; + alloc_params.size = std::max(needed_params_bytes, 2 * adstack_max_reducer_params_capacity_); + alloc_params.host_read = false; + alloc_params.host_write = true; + alloc_params.export_sharing = false; + alloc_params.usage = AllocUsage::Storage; + DeviceAllocation new_alloc; + RhiResult res = llvm_device()->allocate_memory(alloc_params, &new_alloc); + QD_ERROR_IF(res != RhiResult::success, + "Failed to allocate {} bytes for adstack max reducer params buffer (err: {})", alloc_params.size, + int(res)); + adstack_max_reducer_params_alloc_ = std::make_unique<DeviceAllocation>(std::move(new_alloc)); + adstack_max_reducer_params_capacity_ = alloc_params.size; + } + void *params_dev_ptr = get_device_alloc_info_ptr(*adstack_max_reducer_params_alloc_); + copy_h2d(params_dev_ptr, &pending[k].params, needed_params_bytes); + + const std::size_t needed_bytecode_bytes = pending[k].body_bytecode.size(); + if (needed_bytecode_bytes > adstack_max_reducer_bytecode_capacity_) { + Device::AllocParams alloc_params{}; + alloc_params.size = std::max(needed_bytecode_bytes, 2 * adstack_max_reducer_bytecode_capacity_); + alloc_params.host_read = false; + alloc_params.host_write = true; + alloc_params.export_sharing = false; + alloc_params.usage = AllocUsage::Storage; + DeviceAllocation new_alloc; + RhiResult res = llvm_device()->allocate_memory(alloc_params, &new_alloc); + QD_ERROR_IF(res != RhiResult::success, + "Failed to allocate {} bytes for adstack max reducer bytecode buffer (err: {})", alloc_params.size, + int(res)); + adstack_max_reducer_bytecode_alloc_ = std::make_unique<DeviceAllocation>(std::move(new_alloc)); + adstack_max_reducer_bytecode_capacity_ = alloc_params.size; + } + void *bytecode_dev_ptr = get_device_alloc_info_ptr(*adstack_max_reducer_bytecode_alloc_); + copy_h2d(bytecode_dev_ptr, pending[k].body_bytecode.data(), needed_bytecode_bytes); + + fprintf(stderr, + "[trace dispatch_max_reducers] ROUND %zu spec_idx=%zu (k=%zu) calling " + "runtime_eval_adstack_max_reduce: llvm_runtime=%p ctx=%p params=%p bytecode=%p output_slot=%u " + "length=%u num_axes=%u body_node_count=%u\n", + round_idx, i, k, llvm_runtime_, runtime_context_ptr_for_reducer, params_dev_ptr, bytecode_dev_ptr, + pending[k].params.output_slot, pending[k].params.per_axis_length[0], pending[k].params.num_axes, + pending[k].params.body_node_count); + fflush(stderr); + runtime_jit->call("runtime_eval_adstack_max_reduce", llvm_runtime_, + runtime_context_ptr_for_reducer, params_dev_ptr, + bytecode_dev_ptr); + fprintf(stderr, + "[trace dispatch_max_reducers] ROUND %zu spec_idx=%zu (k=%zu) runtime_eval_adstack_max_reduce " + "returned\n", + round_idx, i, k); + fflush(stderr); + } + + // Read back this round's output slots. The runtime function writes int64 values at `outputs[output_slot]`; each + // spec's `output_slot` is its round-local index within `level_dispatch`, so the d2h covers exactly the round's + // dispatched specs. + std::vector<int64_t> outputs_host(level_dispatch.size(), 0); + fprintf(stderr, "[trace dispatch_max_reducers] ROUND %zu d2h outputs %zu bytes from %p\n", round_idx, + needed_output_bytes, outputs_dev_ptr); + fflush(stderr); + copy_d2h(outputs_host.data(), outputs_dev_ptr, needed_output_bytes); + for (std::size_t i = 0; i < level_dispatch.size(); ++i) { + const std::size_t k = level_dispatch[i]; + int64_t v = outputs_host[i]; + if (v == std::numeric_limits<int64_t>::min()) { + v = 0; + } + fprintf(stderr, "[trace dispatch_max_reducers] ROUND %zu readback k=%zu reg=%u stack=%d mor=%d -> %lld\n", + round_idx, k, pending[k].registry_id, pending[k].stack_id, pending[k].mor_node_idx, (long long)v); + fflush(stderr); + result[pending[k].cache_key] = v; + if (cache != nullptr) { + populate_max_reducer_body_observations(pending[k].reads, ctx, cache); + cache->record_max_reducer_eval(pending[k].registry_id, pending[k].stack_id, pending[k].mor_node_idx, v, + std::move(pending[k].reads)); + } + pending[k].dispatched = true; + ++dispatched_count; + } + ++round_idx; + } + fprintf(stderr, "[trace dispatch_max_reducers] EXIT dispatched=%zu dropped=%zu rounds=%zu result_size=%zu\n", + dispatched_count, dropped_count, round_idx, result.size()); + fflush(stderr); + + // Stash the result map on the executor so `publish_adstack_metadata` reads it for substitution per task. + current_max_reducer_results_ = result; + return result; +} + +void LlvmRuntimeExecutor::publish_adstack_lazy_claim_buffers(std::size_t num_tasks) { + if (num_tasks == 0) { + return; + } + // Cache the field-of-LLVMRuntime addresses for the row counter / bound row capacity array pointers. Resolved once per + // program lifetime; subsequent grows write the new array pointers directly to the cached addresses. + if (runtime_adstack_row_counters_field_ptr_ == nullptr) { + auto *const runtime_jit = get_runtime_jit_module(); + runtime_jit->call("runtime_get_adstack_lazy_claim_field_ptrs", llvm_runtime_); + runtime_adstack_row_counters_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_)); + runtime_adstack_bound_row_capacities_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 1, result_buffer_cache_)); + } + + auto grow_to = [&](DeviceAllocationUnique &alloc, std::size_t capacity_u32) { + Device::AllocParams params{}; + params.size = capacity_u32 * sizeof(uint32_t); + params.host_read = false; + params.host_write = false; + params.export_sharing = false; + params.usage = AllocUsage::Storage; + DeviceAllocation new_alloc; + RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); + QD_ERROR_IF(res != RhiResult::success, "Failed to allocate {} bytes for adstack lazy-claim array (err: {})", + params.size, int(res)); + alloc = std::make_unique<DeviceAllocation>(std::move(new_alloc)); + }; + + bool grew = false; + if (num_tasks > adstack_lazy_claim_capacity_) { + std::size_t new_cap = std::max(num_tasks, 2 * adstack_lazy_claim_capacity_); + grow_to(adstack_row_counters_alloc_, new_cap); + grow_to(adstack_bound_row_capacities_alloc_, new_cap); + adstack_lazy_claim_capacity_ = new_cap; + grew = true; + } + void *row_counters_dev_ptr = get_device_alloc_info_ptr(*adstack_row_counters_alloc_); + void *bound_capacities_dev_ptr = get_device_alloc_info_ptr(*adstack_bound_row_capacities_alloc_); + + // After every grow, publish the new array pointers into the runtime so the codegen-emitted GEPs + // (`runtime->adstack_row_counters[task_codegen_id]` and `runtime->adstack_bound_row_capacities[task_codegen_id]`) + // resolve against the live allocations. Skipped between grows because the cached field address holds the same pointer + // value. + auto copy_h2d = [&](void *dst, const void *src, std::size_t bytes) { + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_host_to_device(dst, const_cast<void *>(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_host_to_device(dst, const_cast<void *>(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + std::memcpy(dst, src, bytes); + } + }; + if (grew) { + copy_h2d(runtime_adstack_row_counters_field_ptr_, &row_counters_dev_ptr, sizeof(void *)); + copy_h2d(runtime_adstack_bound_row_capacities_field_ptr_, &bound_capacities_dev_ptr, sizeof(void *)); + } + + // Per-launch reset: zero the counter slots (each task's LCA-block atomic-rmw add starts from 0 and accumulates its + // own claims) and write UINT32_MAX into the capacity slots so the codegen-emitted bounds clamp is inert unless a + // later reducer dispatch overrides slots with tighter counts. One bulk host-side fill + copy rather than per-slot + // stores: the host pays one O(num_tasks) buffer fill per kernel-launch, regardless of arch. + std::vector<uint32_t> zero_buf(num_tasks, 0u); + std::vector<uint32_t> uint_max_buf(num_tasks, std::numeric_limits<uint32_t>::max()); + copy_h2d(row_counters_dev_ptr, zero_buf.data(), num_tasks * sizeof(uint32_t)); + copy_h2d(bound_capacities_dev_ptr, uint_max_buf.data(), num_tasks * sizeof(uint32_t)); +} + +} // namespace quadrants::lang diff --git a/quadrants/runtime/llvm/adstack_lazy_claim/bound_eval.h b/quadrants/runtime/llvm/adstack_lazy_claim/bound_eval.h new file mode 100644 index 0000000000..1647697920 --- /dev/null +++ b/quadrants/runtime/llvm/adstack_lazy_claim/bound_eval.h @@ -0,0 +1,76 @@ +// Stage A of the LLVM sparse-adstack-heap lazy-claim pipeline: per-launch buffer publish + bound-expression +// evaluation. Allocates / clears the per-task lazy-claim arrays (`adstack_row_counters[num_tasks]` for the +// LCA-block atomic-rmw target, `adstack_bound_row_capacities[num_tasks]` for the codegen-emitted bounds +// clamp), then per task evaluates the captured `StaticAdStackBoundExpr` over `[0, length)` and publishes the +// gate-passing count into the per-task capacity slot. CPU walks the gating field on the host directly; CUDA +// / AMDGPU dispatch a single-thread device-side reducer (`runtime_eval_static_bound_count` in +// `runtime_module/runtime.cpp`). Captured `MaxOverRange` leaves are resolved up front by +// `dispatch_max_reducers_for_tasks` so the per-task sizer in Stage B sees them as `Const` substitutions. +// +// Caller responsibility (in `kernel_launcher.cpp` for each arch): invoke `publish_adstack_lazy_claim_buffers` +// once per kernel-launch before the first task dispatches, then per task call either +// `publish_per_task_bound_count_cpu` or `publish_per_task_bound_count_device` (arch-dispatched). Tasks +// without a captured `bound_expr` have those calls early-return with the inert UINT32_MAX sentinel that +// `publish_adstack_lazy_claim_buffers` wrote. +// +// All entry points are member methods of `LlvmRuntimeExecutor` and stay declared in +// `quadrants/runtime/llvm/llvm_runtime_executor.h`. This header carries only the file-private helpers +// shared between the stage's translation unit and (potentially) future cross-stage callers.
+ +#pragma once + +#include <cstdint> +#include <limits> + +#include "quadrants/ir/static_adstack_bound_reducer_device.h" +#include "quadrants/ir/stmt_op_types.h" + +namespace quadrants::lang { + +namespace { + +// Encode the captured `BinaryOpType` (stored as int in `cmp_op`) and evaluate against typed operands. Mirrors the +// SPIR-V reducer's `OpSwitch` over the same encoding. +template <typename T> +inline bool eval_cmp(int cmp_op, T lhs, T rhs) { + switch (static_cast<BinaryOpType>(cmp_op)) { + case BinaryOpType::cmp_lt: + return lhs < rhs; + case BinaryOpType::cmp_le: + return lhs <= rhs; + case BinaryOpType::cmp_gt: + return lhs > rhs; + case BinaryOpType::cmp_ge: + return lhs >= rhs; + case BinaryOpType::cmp_eq: + return lhs == rhs; + case BinaryOpType::cmp_ne: + return lhs != rhs; + default: + return false; + } +} + +// Encode the captured `BinaryOpType` into the 0-5 numeric range the LLVM device reducer's switch consumes. Mirrors the +// SPIR-V reducer's `encode_cmp_op` mapping at `quadrants/runtime/gfx/adstack_bound_reducer_launch.cpp`. +inline uint32_t encode_cmp_op_for_llvm_reducer(int captured_cmp_op) { + switch (static_cast<BinaryOpType>(captured_cmp_op)) { + case BinaryOpType::cmp_lt: + return kLlvmReducerCmpLt; + case BinaryOpType::cmp_le: + return kLlvmReducerCmpLe; + case BinaryOpType::cmp_gt: + return kLlvmReducerCmpGt; + case BinaryOpType::cmp_ge: + return kLlvmReducerCmpGe; + case BinaryOpType::cmp_eq: + return kLlvmReducerCmpEq; + case BinaryOpType::cmp_ne: + return kLlvmReducerCmpNe; + default: + return std::numeric_limits<uint32_t>::max(); + } +} + +} // namespace + +} // namespace quadrants::lang diff --git a/quadrants/runtime/llvm/adstack_lazy_claim/heap_grow.cpp b/quadrants/runtime/llvm/adstack_lazy_claim/heap_grow.cpp new file mode 100644 index 0000000000..c53b52c83a --- /dev/null +++ b/quadrants/runtime/llvm/adstack_lazy_claim/heap_grow.cpp @@ -0,0 +1,295 @@ +// Stage C of the LLVM sparse-adstack-heap lazy-claim pipeline: heap-allocation lifecycle. See +// `adstack_lazy_claim/heap_grow.h` for the stage-level documentation. + +#include "quadrants/runtime/llvm/adstack_lazy_claim/heap_grow.h" + +#include "quadrants/runtime/llvm/llvm_runtime_executor.h" +#include "quadrants/program/adstack_size_expr_eval.h" +#include "quadrants/program/program.h" + +#include <algorithm> +#include <atomic> +#include <cstdio> +#include <cstdlib> +#include <limits> + +#include "quadrants/program/launch_context_builder.h" +#include "quadrants/program/program_impl.h" +#include "quadrants/rhi/llvm/llvm_device.h" + +#include "quadrants/platform/cuda/detect_cuda.h" +#include "quadrants/rhi/cuda/cuda_driver.h" + +#include "quadrants/platform/amdgpu/detect_amdgpu.h" +#include "quadrants/rhi/amdgpu/amdgpu_driver.h" + +namespace quadrants::lang { + +void LlvmRuntimeExecutor::ensure_adstack_heap_int(std::size_t needed_bytes) { + if (needed_bytes == 0 || needed_bytes <= adstack_heap_size_int_) { + return; + } + std::size_t new_size = std::max(needed_bytes, std::size_t(2) * adstack_heap_size_int_); + + Device::AllocParams params{}; + params.size = new_size; + params.host_read = false; + params.host_write = false; + params.export_sharing = false; + params.usage = AllocUsage::Storage; + DeviceAllocation new_alloc; + RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); + QD_ERROR_IF(res != RhiResult::success, + "Failed to allocate {} bytes for the adstack int heap (err: {}). Consider lowering " + "`ad_stack_size` or the per-kernel reverse-mode adstack count.", + new_size, int(res)); + void *new_ptr = get_device_alloc_info_ptr(new_alloc); + auto new_guard = std::make_unique<DeviceAllocation>(std::move(new_alloc)); + + // The split-heap field-of-LLVMRuntime addresses are cached together by `ensure_adstack_heap_float` on its first grow + // (the same `runtime_get_adstack_split_heap_field_ptrs` getter returns all four addresses - float-buffer, float-size, + // int-buffer, int-size - in fixed slot order). On a fresh executor where this is the very first split-heap call, + // resolve the addresses here so we can publish independently of the float heap path. + if (runtime_adstack_heap_buffer_int_field_ptr_ == nullptr) { + auto *const runtime_jit = get_runtime_jit_module(); + runtime_jit->call("runtime_get_adstack_split_heap_field_ptrs", llvm_runtime_); + runtime_adstack_heap_buffer_float_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_)); + runtime_adstack_heap_size_float_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 1, result_buffer_cache_)); + runtime_adstack_heap_buffer_int_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 2, result_buffer_cache_)); + runtime_adstack_heap_size_int_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 3, result_buffer_cache_)); + } + uint64 size_u64 = static_cast<uint64>(new_size); + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_buffer_int_field_ptr_, &new_ptr, + sizeof(void *)); + CUDADriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_size_int_field_ptr_, &size_u64, + sizeof(uint64)); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_buffer_int_field_ptr_, &new_ptr, + sizeof(void *)); + AMDGPUDriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_size_int_field_ptr_, &size_u64, + sizeof(uint64)); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + *reinterpret_cast<void **>(runtime_adstack_heap_buffer_int_field_ptr_) = new_ptr; + *reinterpret_cast<uint64 *>(runtime_adstack_heap_size_int_field_ptr_) = size_u64; + } + + adstack_heap_alloc_int_ = std::move(new_guard); + adstack_heap_size_int_ = new_size; +} + +void LlvmRuntimeExecutor::ensure_per_task_float_heap_post_reducer(std::size_t task_index, + const AdStackSizingInfo &ad_stack, + std::size_t num_threads, + LaunchContextBuilder *ctx) { + // Skip when the task has no float heap need (no f32 allocas, or analysis didn't capture a gate so we wouldn't have + // routed it through the lazy float path on the codegen side). + if (!ad_stack.bound_expr.has_value() || ad_stack.per_thread_stride_float == 0) { + return; + } + + // Read the per-task count the reducer published. On CPU the capacity buffer is host-resident; on CUDA / AMDGPU it's + // device memory and the read is a small (4-byte) DtoH per task. Cost is dominated by the actual main kernel. + uint32_t count = std::numeric_limits<uint32_t>::max(); + if (adstack_bound_row_capacities_alloc_) { + void *capacities_dev_ptr = get_device_alloc_info_ptr(*adstack_bound_row_capacities_alloc_); + char *slot_ptr = static_cast<char *>(capacities_dev_ptr) + task_index * sizeof(uint32_t); + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_device_to_host(&count, slot_ptr, sizeof(uint32_t)); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_device_to_host(&count, slot_ptr, sizeof(uint32_t)); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + count = *reinterpret_cast<uint32_t *>(slot_ptr); + } + } + + // Floor at 1 row when the captured count is zero (no thread passed the gate this launch). The codegen-emitted bounds + // clamp keeps `claimed_row` in [0, count-1] so threads that miss the gate never reach the LCA-block claim - the heap + // row stays unused. A 1-row allocation is cheap and keeps the heap pointer non-null. Clip by the captured + // compile-time loop trip count when known: each iteration claims at most one row at the LCA-block (one `atomic_add` + // per gating iteration), so the heap needs at most `loop_iter_static` rows regardless of how many cells of an + // oversized gating SNode the reducer counted. The analyzer leaves `loop_iter_static == 0` for runtime-bounded loops + // and for CPU LLVM tasks whose `[begin_value, end_value)` is a post-chunking subrange (the unclipped reducer count is + // the right upper bound there). + std::size_t effective_rows = + (count == std::numeric_limits<uint32_t>::max()) ? num_threads : std::max<std::size_t>(count, 1); + if (count != std::numeric_limits<uint32_t>::max() && ad_stack.bound_expr.has_value()) { + // Shared with the SPIR-V launcher: see `clip_effective_rows_by_loop_trip_count` in + // `program/adstack_size_expr_eval.cpp`. LLVM dispatches one thread per loop iteration without the + // SPIR-V dispatch-cap-driven serialisation, so pass `numeric_limits<std::size_t>::max()` to disable the + // dispatched-threads ceiling - any positive trip-count value is a sound upper bound on row claims + // here. `numeric_limits<std::size_t>::max()` is the ceiling sentinel `clip_effective_rows_by_loop_trip_count` + // documents. + Program *prog = (program_impl_ != nullptr) ? program_impl_->program : nullptr; + clip_effective_rows_by_loop_trip_count(effective_rows, *ad_stack.bound_expr, + std::numeric_limits<std::size_t>::max(), prog, ctx); + } + // The per-thread float stride (in bytes) was just published into `runtime->adstack_per_thread_stride_float` by the + // matching `publish_adstack_metadata` call earlier in this task's per-task block. We stash the value host-side so + // we can read it directly here instead of paying a sync DtoH on every bound_expr task. The launcher pairs publish + // + reducer + post-reducer per task with no intervening publish for another task, so the stash is accurate at this + // call site. `AdStackSizingInfo::per_thread_stride_float` from the analysis pre-pass is in entry-count units + // (`2 * max_size`), not bytes, and would massively undersize the heap. + uint64_t stride_float_bytes_u64 = static_cast<uint64_t>(last_published_stride_float_bytes_); + const std::size_t needed_bytes = effective_rows * static_cast<std::size_t>(stride_float_bytes_u64); + // `QD_DEBUG_ADSTACK=1` opt-in diagnostic. Persistent so memory regressions can be debugged without re-instrumenting. + if (std::getenv("QD_DEBUG_ADSTACK")) { + const char *src = (count == std::numeric_limits<uint32_t>::max()) + ? "worst_case_num_threads" + : (count == 0 ? "reducer_zero_floored" : "reducer_count"); + std::fprintf(stderr, + "[adstack_heap] arch=llvm task_idx=%zu kind=F src=%s effective_rows=%zu stride=%llu " + "required_bytes=%zu (%.2f MB)\n", + task_index, src, effective_rows, static_cast<unsigned long long>(stride_float_bytes_u64), needed_bytes, + double(needed_bytes) / (1024.0 * 1024.0)); + std::fflush(stderr); + } + ensure_adstack_heap_float(needed_bytes); +} + +void LlvmRuntimeExecutor::ensure_adstack_heap_float(std::size_t needed_bytes) { + if (needed_bytes == 0 || needed_bytes <= adstack_heap_size_float_) { + return; + } + // Mirror `ensure_adstack_heap`'s amortised-doubling growth and grow-on-demand semantics. The float heap is allocated + // independently from the combined heap so a kernel with bound_expr tasks can shrink the combined slice to int-only + // while still backing float allocas at `row_id_var * stride_float + float_offset`. + std::size_t new_size = std::max(needed_bytes, std::size_t(2) * adstack_heap_size_float_); + + Device::AllocParams params{}; + params.size = new_size; + params.host_read = false; + params.host_write = false; + params.export_sharing = false; + params.usage = AllocUsage::Storage; + DeviceAllocation new_alloc; + RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); + QD_ERROR_IF(res != RhiResult::success, + "Failed to allocate {} bytes for the adstack float heap (err: {}). Consider lowering " + "`ad_stack_size` or the per-kernel reverse-mode adstack count.", + new_size, int(res)); + void *new_ptr = get_device_alloc_info_ptr(new_alloc); + auto new_guard = std::make_unique<DeviceAllocation>(std::move(new_alloc)); + + // Resolve and cache the field-of-LLVMRuntime addresses for the split-heap fields on first grow. The + // `runtime_get_adstack_split_heap_field_ptrs` helper returns four addresses in fixed slot order: float-buffer-ptr, + // float-size, int-buffer-ptr, int-size. We consume the float pair here; the int pair is consumed by the symmetric + // `ensure_adstack_heap_int` above on its own first grow (today the int allocas in bound_expr tasks ride the + // combined heap with a smaller stride).
+ if (runtime_adstack_heap_buffer_float_field_ptr_ == nullptr) { + auto *const runtime_jit = get_runtime_jit_module(); + runtime_jit->call("runtime_get_adstack_split_heap_field_ptrs", llvm_runtime_); + runtime_adstack_heap_buffer_float_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_)); + runtime_adstack_heap_size_float_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 1, result_buffer_cache_)); + runtime_adstack_heap_buffer_int_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 2, result_buffer_cache_)); + runtime_adstack_heap_size_int_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 3, result_buffer_cache_)); + } + uint64 size_u64 = static_cast<uint64>(new_size); + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_buffer_float_field_ptr_, &new_ptr, + sizeof(void *)); + CUDADriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_size_float_field_ptr_, &size_u64, + sizeof(uint64)); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_buffer_float_field_ptr_, &new_ptr, + sizeof(void *)); + AMDGPUDriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_size_float_field_ptr_, &size_u64, + sizeof(uint64)); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + *reinterpret_cast<void **>(runtime_adstack_heap_buffer_float_field_ptr_) = new_ptr; + *reinterpret_cast<uint64 *>(runtime_adstack_heap_size_float_field_ptr_) = size_u64; + } + + adstack_heap_alloc_float_ = std::move(new_guard); + adstack_heap_size_float_ = new_size; +} + +void LlvmRuntimeExecutor::check_adstack_overflow() { + // Called from `synchronize()` on every sync, plus other Quadrants Python entry points wired in + // `Program::check_adstack_overflow_and_raise`. The flag lives in pinned host memory (allocated at + // `materialize_runtime`); polling is a relaxed atomic exchange on the cached host pointer via a + // `std::atomic<int64_t>` reinterpret_cast - no DtoH, no JIT call, no sync drain. Available on all backends because + // the pinned-host memory is in the host process address space regardless of where the kernel that wrote it ran. + // The reinterpret_cast is portable because `std::atomic<int64_t>` is layout-compatible with `int64_t` on every + // target (verified by the static_assert below); see also Itanium ABI / MSVC ABI lock-free guarantees. + // + // Returns early when the slot has not been allocated yet (e.g. a C++ test that constructs Program without + // materializing the runtime and then triggers `Program::finalize -> synchronize`). + static_assert(std::atomic<int64_t>::is_always_lock_free, + "std::atomic<int64_t> must be lock-free for the reinterpret_cast pattern below to be portable"); + if (adstack_overflow_flag_host_ptr_ == nullptr) { + return; + } + int64_t flag = + reinterpret_cast<std::atomic<int64_t> *>(adstack_overflow_flag_host_ptr_)->exchange(0, std::memory_order_relaxed); + if (flag == 0) { + return; + } + // Drain the companion task-id slot in the same poll. Both slots cleared so the next overflow records a fresh + // identity. `task_id == 0` means the kernel that overflowed pre-dates the registry wiring or its + // `ad_stack.registry_id` was unset for any reason (e.g. a deserialised offline-cache task that has not yet been + // re-registered); the diagnose helper falls through to the generic dual-cause message in that case. + uint32_t task_id = 0; + if (adstack_overflow_task_id_host_ptr_ != nullptr) { + int64_t recorded = reinterpret_cast<std::atomic<int64_t> *>(adstack_overflow_task_id_host_ptr_) + ->exchange(0, std::memory_order_relaxed); + task_id = static_cast<uint32_t>(recorded); + } + Program *prog = (program_impl_ != nullptr) ? program_impl_->program : nullptr; + std::string diagnostic; + if (prog != nullptr) { + auto diag = prog->adstack_cache().diagnose_adstack_overflow(task_id); + diagnostic = std::move(diag.message); + // Auto-invalidate the per-task metadata caches when the synchronous sizer rerun confirmed the cache is stale + // (DLPack-bypass cause). The current run is corrupted (we are about to raise), but the next launch's sizer + // reruns from scratch against the live (mutated) state and the kernel runs to completion without further + // user intervention. Unknown / Quadrants-bug cases skip the invalidation so a real sizer bug is not masked + // by silent recompute. + if (diag.confirmed_invalid_cache) { + prog->adstack_cache().invalidate_all_per_task(); + } + } else { + diagnostic = + "Adstack overflow: a reverse-mode autodiff kernel pushed more elements than the adstack capacity " + "allows."; + } + throw QuadrantsAssertionError( + "Adstack overflow: a reverse-mode autodiff kernel pushed more elements " + "than the adstack capacity allows. Raised at the next Quadrants Python " + "entry rather than at the offending kernel launch.\n" + + diagnostic); +} + +} // namespace quadrants::lang diff --git a/quadrants/runtime/llvm/adstack_lazy_claim/heap_grow.h b/quadrants/runtime/llvm/adstack_lazy_claim/heap_grow.h new file mode 100644 index 0000000000..45ef2f2617 --- /dev/null +++ b/quadrants/runtime/llvm/adstack_lazy_claim/heap_grow.h @@ -0,0 +1,22 @@ +// Stage C of the LLVM sparse-adstack-heap lazy-claim pipeline: heap-allocation lifecycle. Sizes the float / +// int adstack heaps from the per-task counts Stage A published and the per-kind strides Stage B resolved, so +// each heap holds exactly `count * stride` bytes per dispatch instead of the dispatched-threads worst case. +// `ensure_adstack_heap_float` and `ensure_adstack_heap_int` are the amortised-doubling growers; on first +// grow they cache the four split-heap field-of-LLVMRuntime addresses through +// `runtime_get_adstack_split_heap_field_ptrs`. `ensure_per_task_float_heap_post_reducer` reads the +// reducer-published per-task count and is the bridge from Stage A's gate output to the lazy float-heap +// sizing. `check_adstack_overflow` is the sync-time consumer that raises when a kernel pushed past the +// captured capacity; it polls the pinned-host overflow flag with a relaxed atomic exchange, no DtoH and no +// JIT call. +// +// All entry points are member methods of `LlvmRuntimeExecutor` and stay declared in +// `quadrants/runtime/llvm/llvm_runtime_executor.h`. This header is reserved for future cross-stage helpers; +// the four methods presently share no helpers other than the runtime field-pointer cache they all consult. + +#pragma once + +namespace quadrants::lang { + +// Reserved for future cross-stage helper declarations.
+ +} // namespace quadrants::lang diff --git a/quadrants/runtime/llvm/adstack_lazy_claim/metadata_publish.cpp b/quadrants/runtime/llvm/adstack_lazy_claim/metadata_publish.cpp new file mode 100644 index 0000000000..8a0b512280 --- /dev/null +++ b/quadrants/runtime/llvm/adstack_lazy_claim/metadata_publish.cpp @@ -0,0 +1,602 @@ +// Stage B of the LLVM sparse-adstack-heap lazy-claim pipeline: per-launch metadata publish. See +// `adstack_lazy_claim/metadata_publish.h` for the stage-level documentation. + +#include "quadrants/runtime/llvm/adstack_lazy_claim/metadata_publish.h" + +#include "quadrants/runtime/llvm/llvm_runtime_executor.h" +#include "quadrants/program/adstack_size_expr_eval.h" +#include "quadrants/program/program.h" + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <limits> + +#include "quadrants/ir/adstack_size_expr_device.h" +#include "quadrants/ir/type_factory.h" +#include "quadrants/program/launch_context_builder.h" +#include "quadrants/program/program_impl.h" +#include "quadrants/rhi/llvm/llvm_device.h" + +#include "quadrants/platform/cuda/detect_cuda.h" +#include "quadrants/rhi/cuda/cuda_driver.h" +#if defined(QD_WITH_CUDA) +#include "quadrants/rhi/cuda/cuda_context.h" +#endif + +#include "quadrants/platform/amdgpu/detect_amdgpu.h" +#include "quadrants/rhi/amdgpu/amdgpu_driver.h" +#if defined(QD_WITH_AMDGPU) +#include "quadrants/rhi/amdgpu/amdgpu_context.h" +#endif + +namespace quadrants::lang { + +std::size_t LlvmRuntimeExecutor::publish_adstack_metadata(const AdStackSizingInfo &ad_stack, + std::size_t num_threads, + LaunchContextBuilder *ctx, + void *device_runtime_context_ptr) { + const auto n_stacks = ad_stack.allocas.size(); + if (n_stacks == 0 || num_threads == 0) { + return 0; + } + auto align_up_8 = [](std::size_t n) -> std::size_t { return (n + 7u) & ~std::size_t{7u}; }; + // Allocate / grow the two device-side metadata arrays. Capacity is in u64 entries, kept at or above n_stacks. + // On GPU these buffers are written exclusively by the device-side sizer kernel (`runtime_eval_adstack_size_expr`); + // on CPU the host evaluator writes them directly via `std::memcpy`. Either way the pointers published into + // `runtime->adstack_offsets` / `adstack_max_sizes` stay stable across launches unless we grow here. + auto grow_to = [&](DeviceAllocationUnique &alloc, std::size_t capacity_u64) { + Device::AllocParams params{}; + params.size = capacity_u64 * sizeof(uint64_t); + params.host_read = false; + params.host_write = false; + params.export_sharing = false; + params.usage = AllocUsage::Storage; + DeviceAllocation new_alloc; + RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); + QD_ERROR_IF(res != RhiResult::success, "Failed to allocate {} bytes for adstack metadata array (err: {})", + params.size, int(res)); + alloc = std::make_unique<DeviceAllocation>(std::move(new_alloc)); + }; + if (n_stacks > adstack_metadata_capacity_) { + std::size_t new_cap = std::max(n_stacks, 2 * adstack_metadata_capacity_); + grow_to(adstack_offsets_alloc_, new_cap); + grow_to(adstack_max_sizes_alloc_, new_cap); + adstack_metadata_capacity_ = new_cap; + } + void *offsets_dev_ptr = get_device_alloc_info_ptr(*adstack_offsets_alloc_); + void *max_sizes_dev_ptr = get_device_alloc_info_ptr(*adstack_max_sizes_alloc_); + + auto copy_h2d = [&](void *dst, const void *src, std::size_t bytes) { + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_host_to_device(dst, const_cast<void *>(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_host_to_device(dst, const_cast<void *>(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + std::memcpy(dst, src, bytes); + } + }; + auto copy_d2h = [&](void *dst, const void *src, std::size_t bytes) { + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_device_to_host(dst, const_cast<void *>(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_device_to_host(dst, const_cast<void *>(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + std::memcpy(dst, src, bytes); + } + }; + + // Cache the runtime-field addresses on the first call; then publish the metadata-array pointers into the + // runtime struct. The stride field is written by the sizer on GPU and by this function on CPU, so we cache the + // address either way. + if (runtime_adstack_stride_field_ptr_ == nullptr) { + auto *const runtime_jit = get_runtime_jit_module(); + runtime_jit->call("runtime_get_adstack_metadata_field_ptrs", llvm_runtime_); + // Slot order: combined-stride, offsets, max_sizes, float-stride, int-stride. Slots 0/1/2 keep the legacy ordering + // for code paths that have not migrated to the split layout; slots 3/4 are new. + runtime_adstack_stride_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_)); + runtime_adstack_offsets_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 1, result_buffer_cache_)); + runtime_adstack_max_sizes_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 2, result_buffer_cache_)); + runtime_adstack_stride_float_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 3, result_buffer_cache_)); + runtime_adstack_stride_int_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 4, result_buffer_cache_)); + } + // The pointed-to scratch allocations are stable across launches (only `grow_to` swaps them). Skip the per-launch + // h2d that publishes the pointer values whenever they have not changed since the last call. On HIP / CUDA each + // skipped pointer-publish is one queue round-trip the launcher would otherwise pay; on a typical reverse-mode + // sweep this fires thousands of times. + if (offsets_dev_ptr != adstack_offsets_dev_ptr_published_) { + copy_h2d(runtime_adstack_offsets_field_ptr_, &offsets_dev_ptr, sizeof(void *)); + adstack_offsets_dev_ptr_published_ = offsets_dev_ptr; + } + if (max_sizes_dev_ptr != adstack_max_sizes_dev_ptr_published_) { + copy_h2d(runtime_adstack_max_sizes_field_ptr_, &max_sizes_dev_ptr, sizeof(void *)); + adstack_max_sizes_dev_ptr_published_ = max_sizes_dev_ptr; + } + + std::size_t stride = 0; + const bool is_gpu_llvm = (config_.arch == Arch::cuda || config_.arch == Arch::amdgpu); + + // Shared GPU async publish helper: pack `[stride_combined, stride_float, stride_int, offsets[n_stacks], + // max_sizes[n_stacks]]` into the pinned-host scratch (grow on demand, double-amortised), then issue 5 async H2Ds + // on the active stream and record the completion event. Used by both the host-eval branch (CUDA / AMDGPU + // resolvable size_exprs) and the on-device-sizer cache-hit branch. The driver's H2D DMA reads from the pinned + // bytes at execution time, so a `wait_pending()` at the top of the next call defends against an unusual + // interleaving where the GPU queue is backlogged and the next launch enters before the previous launch's last + // copy has been consumed. Only callable when `is_gpu_llvm` is true.
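+ // The stride/offsets pointer-publish skip above, reduced to a sketch (an illustrative helper, not
+ // the shipped code; `copy_h2d` as defined earlier in this function):
+ //
+ //   void publish_if_changed(void *field_ptr, void *&published, void *current) {
+ //     if (current == published)
+ //       return;  // the runtime already holds this pointer value; skip the queue round-trip
+ //     copy_h2d(field_ptr, &current, sizeof(void *));
+ //     published = current;
+ //   }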
+  // Shared GPU async publish helper: pack `[stride_combined, stride_float, stride_int, offsets[n_stacks],
+  // max_sizes[n_stacks]]` into the pinned-host scratch (grown on demand with amortised doubling), then issue 5
+  // async H2Ds on the active stream and record the completion event. Used by both the host-eval branch (CUDA /
+  // AMDGPU resolvable size_exprs) and the on-device-sizer cache-hit branch. The driver's H2D DMA reads from the
+  // pinned bytes at execution time, so a `wait_pending()` at the top of the next call defends against an unusual
+  // interleaving where the GPU queue is backlogged and the next launch enters before the previous launch's last
+  // copy has been consumed. Only callable when `is_gpu_llvm` is true.
+  auto publish_metadata_pinned_async = [&](const uint64_t *offsets_src, const uint64_t *max_sizes_src,
+                                           uint64_t stride_combined_u64, uint64_t stride_float_u64,
+                                           uint64_t stride_int_u64) {
+    const std::size_t header_bytes = 3 * sizeof(uint64_t);
+    const std::size_t array_bytes = n_stacks * sizeof(uint64_t);
+    const std::size_t total_bytes = header_bytes + 2 * array_bytes;
+    auto wait_pending = [this]() {
+      if (!pinned_metadata_event_pending_) {
+        return;
+      }
+#if defined(QD_WITH_CUDA)
+      if (config_.arch == Arch::cuda) {
+        CUDADriver::get_instance().event_synchronize(pinned_metadata_event_);
+      }
+#endif
+#if defined(QD_WITH_AMDGPU)
+      if (config_.arch == Arch::amdgpu) {
+        AMDGPUDriver::get_instance().event_synchronize(pinned_metadata_event_);
+      }
+#endif
+      pinned_metadata_event_pending_ = false;
+    };
+    if (total_bytes > pinned_metadata_scratch_capacity_) {
+      wait_pending();
+      if (pinned_metadata_scratch_ != nullptr) {
+#if defined(QD_WITH_CUDA)
+        if (config_.arch == Arch::cuda) {
+          CUDADriver::get_instance().mem_free_host(pinned_metadata_scratch_);
+        }
+#endif
+#if defined(QD_WITH_AMDGPU)
+        if (config_.arch == Arch::amdgpu) {
+          AMDGPUDriver::get_instance().mem_free_host(pinned_metadata_scratch_);
+        }
+#endif
+        pinned_metadata_scratch_ = nullptr;
+      }
+      const std::size_t new_capacity = std::max(total_bytes, 2 * pinned_metadata_scratch_capacity_);
+#if defined(QD_WITH_CUDA)
+      if (config_.arch == Arch::cuda) {
+        CUDADriver::get_instance().mem_alloc_host(&pinned_metadata_scratch_, new_capacity);
+      }
+#endif
+#if defined(QD_WITH_AMDGPU)
+      if (config_.arch == Arch::amdgpu) {
+        // `hipHostMallocDefault == 0`. Coherent / portable / write-combined flags are intentionally not set; the
+        // workload is small payloads written linearly by the host and DMA-read by the GPU once.
+        AMDGPUDriver::get_instance().mem_alloc_host(&pinned_metadata_scratch_, new_capacity, 0u);
+      }
+#endif
+      pinned_metadata_scratch_capacity_ = new_capacity;
+    }
+    if (pinned_metadata_event_ == nullptr) {
+      // `cuEventCreate` flag `0` (CU_EVENT_DEFAULT) keeps timing enabled; that costs nothing to set up here and
+      // lets future profilers attach without re-creating the event. `hipEventCreateWithFlags` takes the same
+      // encoding.
+#if defined(QD_WITH_CUDA)
+      if (config_.arch == Arch::cuda) {
+        CUDADriver::get_instance().event_create(&pinned_metadata_event_, 0u);
+      }
+#endif
+#if defined(QD_WITH_AMDGPU)
+      if (config_.arch == Arch::amdgpu) {
+        AMDGPUDriver::get_instance().event_create(&pinned_metadata_event_, 0u);
+      }
+#endif
+    }
+    wait_pending();
+    auto *pinned = static_cast<uint64_t *>(pinned_metadata_scratch_);
+    pinned[0] = stride_combined_u64;
+    pinned[1] = stride_float_u64;
+    pinned[2] = stride_int_u64;
+    std::memcpy(pinned + 3, offsets_src, array_bytes);
+    std::memcpy(pinned + 3 + n_stacks, max_sizes_src, array_bytes);
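> The payload shape the lambda above packs is easy to pin down in isolation. A sketch of the same `[header, offsets, max_sizes]` layout over plain heap memory (the `pack` helper is illustrative only; the real code writes into the pinned scratch):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

std::vector<uint64_t> pack(uint64_t s_comb, uint64_t s_f, uint64_t s_i,
                           const std::vector<uint64_t> &offsets,
                           const std::vector<uint64_t> &max_sizes) {
  const std::size_t n = offsets.size();  // n_stacks
  std::vector<uint64_t> pinned(3 + 2 * n);
  pinned[0] = s_comb;  // combined stride (legacy slot)
  pinned[1] = s_f;     // float-heap stride
  pinned[2] = s_i;     // int-heap stride
  std::memcpy(pinned.data() + 3, offsets.data(), n * sizeof(uint64_t));
  std::memcpy(pinned.data() + 3 + n, max_sizes.data(), n * sizeof(uint64_t));
  return pinned;
}

int main() {
  auto p = pack(64, 40, 24, {0, 16}, {4, 8});
  assert(p.size() == 7);
  assert(p[3 + 2] == 4);  // max_sizes slice begins at index 3 + n_stacks
}
```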
+    // Queue the metadata copies on the stream the subsequent main-kernel dispatch will run on, so the GPU
+    // stream-orders the copies before the kernel reads `adstack_max_sizes` etc. CUDA: `CUDAContext::get_stream()`
+    // (configurable via `set_stream`, defaults to the null stream); AMDGPU: always the default stream because
+    // `AMDGPUContext::launch` passes `nullptr` to `hipLaunchKernel`.
+#if defined(QD_WITH_CUDA)
+    if (config_.arch == Arch::cuda) {
+      void *active_stream = CUDAContext::get_instance().get_stream();
+      CUDADriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_field_ptr_, pinned,
+                                                             sizeof(uint64_t), active_stream);
+      if (runtime_adstack_stride_float_field_ptr_ != nullptr) {
+        CUDADriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_float_field_ptr_, pinned + 1,
+                                                               sizeof(uint64_t), active_stream);
+      }
+      if (runtime_adstack_stride_int_field_ptr_ != nullptr) {
+        CUDADriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_int_field_ptr_, pinned + 2,
+                                                               sizeof(uint64_t), active_stream);
+      }
+      CUDADriver::get_instance().memcpy_host_to_device_async(offsets_dev_ptr, pinned + 3, array_bytes,
+                                                             active_stream);
+      CUDADriver::get_instance().memcpy_host_to_device_async(max_sizes_dev_ptr, pinned + 3 + n_stacks, array_bytes,
+                                                             active_stream);
+      CUDADriver::get_instance().event_record(pinned_metadata_event_, active_stream);
+    }
+#endif
+#if defined(QD_WITH_AMDGPU)
+    if (config_.arch == Arch::amdgpu) {
+      void *active_stream = nullptr;
+      AMDGPUDriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_field_ptr_, pinned,
+                                                               sizeof(uint64_t), active_stream);
+      if (runtime_adstack_stride_float_field_ptr_ != nullptr) {
+        AMDGPUDriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_float_field_ptr_,
+                                                                 pinned + 1, sizeof(uint64_t), active_stream);
+      }
+      if (runtime_adstack_stride_int_field_ptr_ != nullptr) {
+        AMDGPUDriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_int_field_ptr_, pinned + 2,
+                                                                 sizeof(uint64_t), active_stream);
+      }
+      AMDGPUDriver::get_instance().memcpy_host_to_device_async(offsets_dev_ptr, pinned + 3, array_bytes,
+                                                               active_stream);
+      AMDGPUDriver::get_instance().memcpy_host_to_device_async(max_sizes_dev_ptr, pinned + 3 + n_stacks,
+                                                               array_bytes, active_stream);
+      AMDGPUDriver::get_instance().event_record(pinned_metadata_event_, active_stream);
+    }
+#endif
+    pinned_metadata_event_pending_ = true;
+  };
+
+  // Host-eval fast path. The on-device sizer kernel exists to handle one specific leaf, `ExternalTensorRead`,
+  // whose ndarray data lives in GPU-private memory (`cudaMalloc` / `hipMalloc`, no UVA fallback) and thus
+  // cannot be touched from the host. Every other SizeExpr leaf - `Const`, `BoundVariable`,
+  // `ExternalTensorShape`, `FieldLoad` - is host-resolvable through the existing `evaluate_adstack_size_expr`
+  // path, so when the kernel's SizeExprs are all `ExternalTensorRead`-free we can skip the encode + bytecode
+  // h2d + sizer-kernel launch + d2h-stride pipeline entirely and write the metadata directly via `copy_h2d`.
+  // On CUDA the saved `cuMemcpyDtoH` for the per-launch stride readback is the dominant cost: every reverse-
+  // mode kernel launch in a 100-substep test paid one such synchronous DtoH each, and that compound stall
+  // accounted for the bulk of the GPU launch overhead under adstack mode. The condition is computed once per
+  // launch by scanning each stack's `nodes` vector for an `ExternalTensorRead` leaf; the scan is O(total
+  // SizeExpr nodes), well below the cost of the cheapest h2d / d2h on any LLVM GPU backend.
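> The scan the comment above motivates (implemented just below) reduces to one structural predicate over the serialized trees. A simplified standalone version, with stand-in `Node` / `Kind` types in place of the real serialized IR:

```cpp
#include <vector>

enum class Kind { Const, BoundVariable, ExternalTensorShape, FieldLoad, ExternalTensorRead };
struct Node { Kind kind; };
struct Expr { std::vector<Node> nodes; };

// True iff no SizeExpr leaf needs a device-side dereference.
bool all_host_resolvable(const std::vector<Expr> &size_exprs) {
  for (const auto &expr : size_exprs)
    for (const auto &node : expr.nodes)
      if (node.kind == Kind::ExternalTensorRead)
        return false;  // ndarray payload lives in GPU-private memory
  return true;
}

int main() {
  std::vector<Expr> a = {{{{Kind::Const}, {Kind::FieldLoad}}}};
  std::vector<Expr> b = {{{{Kind::ExternalTensorRead}}}};
  return all_host_resolvable(a) && !all_host_resolvable(b) ? 0 : 1;
}
```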
+  bool all_size_exprs_host_resolvable = true;
+  for (std::size_t i = 0; i < n_stacks && all_size_exprs_host_resolvable; ++i) {
+    if (i >= ad_stack.size_exprs.size()) {
+      continue;
+    }
+    for (const auto &node : ad_stack.size_exprs[i].nodes) {
+      if (static_cast<SizeExpr::Kind>(node.kind) == SizeExpr::Kind::ExternalTensorRead) {
+        all_size_exprs_host_resolvable = false;
+        break;
+      }
+    }
+  }
+  const bool use_host_eval = !is_gpu_llvm || all_size_exprs_host_resolvable;
+  // Per-kind byte strides resolved either host-side (host-eval branch) or by reading back from the device runtime
+  // struct after the sizer kernel ran (GPU branch). Used below to size the float / int heaps independently for the
+  // unconditional split-heap layout.
+  std::size_t stride_float_bytes = 0;
+  std::size_t stride_int_bytes = 0;
+  if (use_host_eval) {
+    // CPU + GPU-without-ExternalTensorRead path: run the host evaluator directly. On CPU we use synchronous
+    // `copy_h2d` (just `std::memcpy` for that arch), but on CUDA / AMDGPU we ship the same payload through
+    // pinned-host memory via async `cuMemcpyHtoDAsync` / `hipMemcpyHtoDAsync`, so the host returns immediately
+    // after queueing the copies on the default stream, and the subsequent main-kernel launch (also on the
+    // default stream) stream-orders after the copies. The synchronous `cuMemcpyHtoD_v2` path used to block
+    // the host on every one of the three writes we issue per launch; with thousands of reverse-mode launches
+    // per `test_differentiable_rigid` run, those serial host stalls were a measurable fraction of wallclock.
+    // `FieldLoad` is serviced by `SNodeRwAccessorsBank` regardless of arch.
+    // Guard `program_impl_->program` lookups against the C++-only-tests setup where `program_impl_` itself is
+    // null; the on-device branch below already does this and falls back to `max_size_compile_time`.
+    Program *prog = (program_impl_ != nullptr) ? program_impl_->program : nullptr;
+    // Scope the per-stack `evaluate_adstack_size_expr` calls below under one shared read cache.
+    SizeExprLaunchScope launch_scope;
+    // Snapshot the dispatched-results map for this kernel before the per-stack walk. The body of any captured
+    // `MaxOverRange` may host-resolve a `FieldLoad` leaf via `read_field_with_launch_cache`, which dispatches a
+    // snode-reader kernel that reenters `dispatch_max_reducers_for_tasks` and clears
+    // `current_max_reducer_results_`. Reading the live executor field per stack would let that recursive clear
+    // turn `stack_id == 0`'s substitution branch into `stack_id == 1`'s empty-map fallback - whose host walk then
+    // trips the per-task sizer's `1 << 24` cap on out-of-grammar shapes that the recognizer DID capture. Pin the
+    // snapshot by-value so the substitution loop stays self-consistent.
+    const auto local_max_reducer_results = current_max_reducer_results_;
+    std::vector<uint64_t> host_max_sizes(n_stacks);
+    for (std::size_t i = 0; i < n_stacks; ++i) {
+      const SerializedSizeExpr *expr = (i < ad_stack.size_exprs.size()) ? &ad_stack.size_exprs[i] : nullptr;
+      int64_t v = -1;
+      if (expr != nullptr && !expr->nodes.empty() && prog != nullptr) {
+        // Substitute any captured `MaxOverRange` whose result the max-reducer dispatched into a `Const` before
+        // the host evaluator walks the tree. Mirrors `eval_per_task_metadata_on_host` on the SPIR-V side.
+        // The empty-results fast path passes the live `expr` pointer directly so `size_expr_cache_` (keyed by
+        // `SerializedSizeExpr *`) stays warm across launches; the non-empty branch builds a stack-local
+        // substituted tree and routes through `evaluate_adstack_size_expr_no_cache` so the transient pointer
+        // never aliases unrelated cache entries.
+        if (local_max_reducer_results.empty()) {
+          v = evaluate_adstack_size_expr(*expr, prog, ctx);
+        } else {
+          const SerializedSizeExpr substituted = substitute_precomputed_max_over_range(
+              *expr, ad_stack.registry_id, static_cast<uint32_t>(i), local_max_reducer_results);
+          v = evaluate_adstack_size_expr_no_cache(substituted, prog, ctx);
+        }
+      }
+      if (v < 0) {
+        v = static_cast<int64_t>(ad_stack.allocas[i].max_size_compile_time);
+      }
+      host_max_sizes[i] = static_cast<uint64_t>(std::max<int64_t>(v, 1));
+    }
+    // Unconditional split-heap layout: float allocas live at `host_offsets[i]` within the float-only slice
+    // (addressed on the codegen side as `heap_float + row_id_var * stride_float + float_offset` for bound_expr
+    // tasks, or `heap_float + linear_tid * stride_float + float_offset` for non-bound_expr tasks); int allocas
+    // live at `host_offsets[i]` within the int-only slice (addressed as `heap_int + linear_tid * stride_int +
+    // int_offset`). Same scheme regardless of `bound_expr`, so the heap layout matches the SPIR-V backend's
+    // unconditional split into `BufferType::AdStackHeapFloat` + `AdStackHeapInt`. The legacy combined-heap path
+    // is no longer used by the codegen; the combined stride / heap fields stay in the LLVMRuntime struct only as
+    // a transitional fallback for offline-cache-loaded kernels that predate the split, and the published
+    // `adstack_per_thread_stride` mirrors `stride_int` so any such kernel sees the smaller int-only stride.
+    std::vector<uint64_t> host_offsets(n_stacks);
+    for (std::size_t i = 0; i < n_stacks; ++i) {
+      const std::size_t step =
+          align_up_8(sizeof(int64_t) + ad_stack.allocas[i].entry_size_bytes * host_max_sizes[i]);
+      const bool is_float = ad_stack.allocas[i].heap_kind == AdStackAllocaInfo::HeapKind::Float;
+      host_offsets[i] = is_float ? stride_float_bytes : stride_int_bytes;
+      if (is_float) {
+        stride_float_bytes += step;
+      } else {
+        stride_int_bytes += step;
+      }
+    }
+    stride = stride_int_bytes;
+    uint64_t stride_combined_u64 = static_cast<uint64_t>(stride);
+    uint64_t stride_float_u64 = static_cast<uint64_t>(stride_float_bytes);
+    uint64_t stride_int_u64 = static_cast<uint64_t>(stride_int_bytes);
+    if (!is_gpu_llvm) {
+      copy_h2d(offsets_dev_ptr, host_offsets.data(), n_stacks * sizeof(uint64_t));
+      copy_h2d(max_sizes_dev_ptr, host_max_sizes.data(), n_stacks * sizeof(uint64_t));
+      copy_h2d(runtime_adstack_stride_field_ptr_, &stride_combined_u64, sizeof(uint64_t));
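> The offset-assignment loop above is the whole split-heap layout. A worked standalone version (stand-in `Alloca` struct; the `sizeof(int64_t)` header per stack is taken from the formula in the loop, presumably a stack-top slot):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

struct Alloca { std::size_t entry_size_bytes; bool is_float; };
inline std::size_t align_up_8(std::size_t n) { return (n + 7u) & ~std::size_t{7u}; }

int main() {
  std::vector<Alloca> allocas = {{4, true}, {4, false}, {8, true}};
  std::vector<uint64_t> max_sizes = {16, 16, 4};
  std::size_t stride_f = 0, stride_i = 0;
  for (std::size_t k = 0; k < allocas.size(); ++k) {
    // Per-stack footprint: 8-byte header + max_size entries, rounded to 8 bytes.
    const std::size_t step =
        align_up_8(sizeof(int64_t) + allocas[k].entry_size_bytes * max_sizes[k]);
    std::size_t &cursor = allocas[k].is_float ? stride_f : stride_i;
    std::printf("stack %zu -> offset %zu in the %s slice\n", k, cursor,
                allocas[k].is_float ? "float" : "int");
    cursor += step;  // the final cursors become the per-kind row strides
  }
  std::printf("stride_float=%zu stride_int=%zu\n", stride_f, stride_i);
}
```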
+      // Per-kind strides used by the split-heap codegen path; harmless when the codegen has not migrated yet (the
+      // kernel reads only the combined stride). Skipped when the cache is empty (first launch on a stale executor
+      // instance where `runtime_get_adstack_metadata_field_ptrs` populated only the legacy slots); the null check
+      // is defensive - any host write to `nullptr` would crash with no diagnostic.
+      if (runtime_adstack_stride_float_field_ptr_ != nullptr) {
+        copy_h2d(runtime_adstack_stride_float_field_ptr_, &stride_float_u64, sizeof(uint64_t));
+      }
+      if (runtime_adstack_stride_int_field_ptr_ != nullptr) {
+        copy_h2d(runtime_adstack_stride_int_field_ptr_, &stride_int_u64, sizeof(uint64_t));
+      }
+    } else {
+      publish_metadata_pinned_async(host_offsets.data(), host_max_sizes.data(), stride_combined_u64,
+                                    stride_float_u64, stride_int_u64);
+    }
+  } else {
+    // GPU (CUDA / AMDGPU): encode the SizeExpr trees into device bytecode, upload, launch the sizer runtime
+    // function, read back just the computed stride. The sizer kernel writes `adstack_max_sizes[]`,
+    // `adstack_offsets[]`, and `adstack_per_thread_stride` directly into the runtime struct and the metadata
+    // arrays above - no further host-writes to those fields are needed this launch.
+    //
+    // Why this architecture rather than host-eval: on CUDA / AMDGPU the ndarray data lives in GPU-private memory
+    // (plain `cudaMalloc` / `hipMalloc`, not managed / unified), so the host evaluator's `ExternalTensorRead`
+    // deref reads garbage. Moving the interpreter on-device keeps the pointer semantics intact - it reads the
+    // data pointer out of `ctx->arg_buffer` (which the kernel will read too) and dereferences it where the
+    // memory lives, with no migration / readback of the ndarray payload itself.
+    //
+    // Per-task metadata cache fast path: the sizer kernel's output (offsets / max_sizes / strides) is a
+    // deterministic function of (a) the per-task `AdStackSizingInfo *` (compile-time bytecode shape, stable
+    // for the kernel's lifetime), (b) every SNode value a `FieldLoad` leaf reads, and (c) every ndarray
+    // value an `ExternalTensorRead` leaf reads. Each launcher (cpu / cuda / amdgpu) bumps
+    // `Program::snode_write_gen_` / `ndarray_data_gen_` for everything this kernel may mutate before
+    // calling here, so the per-source generation snapshots stored alongside the cached payload catch any
+    // input change between launches and force a fresh sizer dispatch when needed. On hit, the cached
+    // offsets / max_sizes / strides are republished into the runtime struct via the same publish paths the
+    // host-eval branch above uses (pinned-async on GPU, `copy_h2d` on CPU), and the entire bytecode-encode +
+    // h2d + sizer-kernel launch + 3x DtoH-stride pipeline is skipped. The cost of the sizer dispatch + DtoH
+    // stalls is small per launch on CUDA / AMDGPU, but a long sequence of reverse-mode launches over the same
+    // kernel pays it once per launch; the cache amortises that to once per generation-bump.
+    Program *prog = (program_impl_ != nullptr) ? program_impl_->program : nullptr;
+    bool llvm_metadata_cache_hit = false;
+    if (prog != nullptr) {
+      AdStackCache::LlvmPerTaskAdStackCacheEntry entry;
+      if (prog->adstack_cache().try_llvm_per_task_ad_stack_cache_hit(static_cast<const void *>(&ad_stack), ctx,
+                                                                     entry)) {
+        QD_ASSERT(entry.offsets.size() == n_stacks && entry.max_sizes.size() == n_stacks);
+        // Publish the cached payload through the pinned-host async pipeline shared with the host-eval
+        // branch above: one pinned-scratch pack + five `memcpy_host_to_device_async` issued on the same
+        // stream the main kernel will dispatch on, ordered behind the previous launch's
+        // `pinned_metadata_event_pending_` wait.
+        // Packing the same `[stride_combined, stride_float, stride_int, offsets[n_stacks], max_sizes[n_stacks]]`
+        // shape keeps both branches' DMA pattern identical and removes the per-launch sync round-trips a
+        // `copy_h2d` would otherwise impose; on CPU `copy_h2d` is `memcpy` already, so we keep the direct path
+        // there.
+        if (!is_gpu_llvm) {
+          copy_h2d(offsets_dev_ptr, entry.offsets.data(), n_stacks * sizeof(uint64_t));
+          copy_h2d(max_sizes_dev_ptr, entry.max_sizes.data(), n_stacks * sizeof(uint64_t));
+          copy_h2d(runtime_adstack_stride_field_ptr_, &entry.stride_combined, sizeof(uint64_t));
+          if (runtime_adstack_stride_float_field_ptr_ != nullptr) {
+            copy_h2d(runtime_adstack_stride_float_field_ptr_, &entry.stride_float, sizeof(uint64_t));
+          }
+          if (runtime_adstack_stride_int_field_ptr_ != nullptr) {
+            copy_h2d(runtime_adstack_stride_int_field_ptr_, &entry.stride_int, sizeof(uint64_t));
+          }
+        } else {
+          publish_metadata_pinned_async(entry.offsets.data(), entry.max_sizes.data(), entry.stride_combined,
+                                        entry.stride_float, entry.stride_int);
+        }
+        stride = static_cast<std::size_t>(entry.stride_combined);
+        stride_float_bytes = static_cast<std::size_t>(entry.stride_float);
+        stride_int_bytes = static_cast<std::size_t>(entry.stride_int);
+        llvm_metadata_cache_hit = true;
+      }
+    }
+    if (!llvm_metadata_cache_hit) {
+      std::vector<uint8_t> bytecode;
+      if (program_impl_ != nullptr && program_impl_->program != nullptr) {
+        bytecode = encode_adstack_size_expr_device_bytecode(ad_stack, program_impl_->program, ctx,
+                                                            current_max_reducer_results_);
+      } else {
+        // No program attached (rare: C++-only tests that construct Program without a full runtime). Fall through
+        // to compile-time bounds by emitting an empty-tree bytecode - the device interpreter sees
+        // `root_node_idx == -1` for every stack and routes to `max_size_compile_time`.
+        bytecode = encode_adstack_size_expr_device_bytecode(ad_stack, nullptr, ctx, current_max_reducer_results_);
+      }
+      // Grow the scratch buffer if the bytecode outgrew the cached capacity. Amortised doubling keeps the
+      // allocation traffic O(log max_bytecode_bytes) across a run.
+      const std::size_t bytecode_bytes = bytecode.size();
+      if (bytecode_bytes > adstack_sizer_bytecode_capacity_) {
+        std::size_t new_cap = std::max(bytecode_bytes, 2 * adstack_sizer_bytecode_capacity_);
+        Device::AllocParams params{};
+        params.size = new_cap;
+        params.host_read = false;
+        params.host_write = false;
+        params.export_sharing = false;
+        params.usage = AllocUsage::Storage;
+        DeviceAllocation new_alloc;
+        RhiResult res = llvm_device()->allocate_memory(params, &new_alloc);
+        QD_ERROR_IF(res != RhiResult::success,
+                    "Failed to allocate {} bytes for the adstack sizer bytecode scratch buffer (err: {})",
+                    params.size, int(res));
+        adstack_sizer_bytecode_alloc_ = std::make_unique<DeviceAllocationGuard>(std::move(new_alloc));
+        adstack_sizer_bytecode_capacity_ = new_cap;
+      }
+      void *bytecode_dev_ptr = get_device_alloc_info_ptr(*adstack_sizer_bytecode_alloc_);
+      copy_h2d(bytecode_dev_ptr, bytecode.data(), bytecode_bytes);
+
+      // Invoke the device interpreter. On CUDA / AMDGPU `JITModule::call` launches this as a single-thread kernel
+      // on the default stream and stream-orders it before the subsequent main-kernel dispatch, so the writes we
+      // do here are visible by the time the user's kernel reads `adstack_max_sizes` etc.
+      //
+      // The sizer kernel dereferences `ctx->arg_buffer` on device (that's how it resolves `ExternalTensorRead`
+      // leaves against ndarray pointers the caller packed into the arg buffer).
+      // AMDGPU always stages a device-side copy of `RuntimeContext` because HIP has no UVA fallback and the host
+      // pointer faults with `hipErrorIllegalAddress`. CUDA stages the device copy only when the driver + kernel
+      // do not expose HMM / system-allocated memory (queried via `CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS`):
+      // CUDA UVA covers pinned / CUDA-managed memory only, not the plain `std::make_unique<RuntimeContext>()`
+      // backing, so a host pointer works on HMM-capable setups but faults otherwise (Turing without HMM, Windows,
+      // pre-535 Linux drivers) with `CUDA_ERROR_ILLEGAL_ADDRESS` at the next DtoH sync (`illegal memory access
+      // ... while calling memcpy_device_to_host`). When the caller passes `nullptr` (HMM-capable CUDA) we fall
+      // back to the host pointer; the launcher gates the allocation so HMM-equipped setups pay no staging cost.
+      auto *const runtime_jit = get_runtime_jit_module();
+      void *runtime_context_ptr_for_sizer = device_runtime_context_ptr != nullptr
+                                                ? device_runtime_context_ptr
+                                                : static_cast<void *>(&ctx->get_context());
+      runtime_jit->call("runtime_eval_adstack_size_expr", llvm_runtime_, runtime_context_ptr_for_sizer,
+                        bytecode_dev_ptr);
+
+      // Read back the per-kind strides published by `runtime_eval_adstack_size_expr` so we can size the float and
+      // int heaps independently host-side. The combined stride is unused by the split-heap codegen but kept
+      // around for legacy-kernel backward compatibility (mirrors `stride_int` in the unconditional-split layout).
+      uint64_t stride_combined_readback = 0;
+      uint64_t stride_float_readback = 0;
+      uint64_t stride_int_readback = 0;
+      copy_d2h(&stride_combined_readback, runtime_adstack_stride_field_ptr_, sizeof(uint64_t));
+      if (runtime_adstack_stride_float_field_ptr_ != nullptr) {
+        copy_d2h(&stride_float_readback, runtime_adstack_stride_float_field_ptr_, sizeof(uint64_t));
+      }
+      if (runtime_adstack_stride_int_field_ptr_ != nullptr) {
+        copy_d2h(&stride_int_readback, runtime_adstack_stride_int_field_ptr_, sizeof(uint64_t));
+      }
+      stride = static_cast<std::size_t>(stride_combined_readback);
+      stride_float_bytes = static_cast<std::size_t>(stride_float_readback);
+      stride_int_bytes = static_cast<std::size_t>(stride_int_readback);
+
+      // Record the cache entry so the next launch on this kernel can skip the sizer pipeline. We also
+      // need to read back the offsets / max_sizes arrays the sizer wrote to the device buffers - the
+      // cache-hit path above republishes them, so we must store host copies here. n_stacks is small
+      // (a few dozen at most for any reasonable kernel), so the extra DtoH cost is negligible
+      // compared to the dispatch + sizer-kernel launch we are about to amortise away.
+      if (prog != nullptr) {
+        std::vector<uint64_t> offsets_readback(n_stacks);
+        std::vector<uint64_t> max_sizes_readback(n_stacks);
+        copy_d2h(offsets_readback.data(), offsets_dev_ptr, n_stacks * sizeof(uint64_t));
+        copy_d2h(max_sizes_readback.data(), max_sizes_dev_ptr, n_stacks * sizeof(uint64_t));
+        // Walk size_exprs structurally to gather the dependency keys (snode_ids referenced via
+        // FieldLoad, arg_ids referenced via ExternalTensorShape / ExternalTensorRead). Pure tree
+        // inspection - no live value reads, no nested kernel launches. Mirrors the SPIR-V analogue.
+        std::unordered_set<int> snode_ids;
+        std::unordered_set<int> arg_ids;
+        for (const auto &expr : ad_stack.size_exprs) {
+          for (const auto &node : expr.nodes) {
+            switch (static_cast<SizeExpr::Kind>(node.kind)) {
+              case SizeExpr::Kind::FieldLoad:
+                if (node.snode_id >= 0)
+                  snode_ids.insert(node.snode_id);
+                break;
+              case SizeExpr::Kind::ExternalTensorShape:
+              case SizeExpr::Kind::ExternalTensorRead:
+                if (!node.arg_id_path.empty())
+                  arg_ids.insert(node.arg_id_path.front());
+                break;
+              default:
+                break;
+            }
+          }
+        }
+        std::vector<std::pair<int, uint64_t>> snode_gens;
+        snode_gens.reserve(snode_ids.size());
+        for (int snode_id : snode_ids) {
+          snode_gens.emplace_back(snode_id, prog->adstack_cache().snode_write_gen(snode_id));
+        }
+        std::vector<std::tuple<int, void *, uint64_t>> arg_gens;
+        arg_gens.reserve(arg_ids.size());
+        for (int arg_id : arg_ids) {
+          ArgArrayPtrKey data_key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
+          auto ap_it = ctx->array_ptrs.find(data_key);
+          void *devalloc = (ap_it == ctx->array_ptrs.end()) ? nullptr : ap_it->second;
+          arg_gens.emplace_back(arg_id, devalloc, prog->adstack_cache().ndarray_data_gen(devalloc));
+        }
+        prog->adstack_cache().record_llvm_per_task_ad_stack(
+            static_cast<const void *>(&ad_stack), std::move(offsets_readback), std::move(max_sizes_readback),
+            stride_combined_readback, stride_float_readback, stride_int_readback, std::move(snode_gens),
+            std::move(arg_gens));
+      }
+    }  // end if (!llvm_metadata_cache_hit)
+  }
+
+  // Legacy combined heap: not allocated. The unconditional-split codegen reads `heap_float` for f32 allocas and
+  // `heap_int` for i32 / u1 allocas; the legacy `adstack_heap_buffer` field is never dereferenced by
+  // freshly-compiled kernels. Skipping the allocation drops ~`stride_int_bytes * num_threads` of unused VRAM
+  // (multiple GB on heavy reverse-mode kernels on Nvidia / AMDGPU at saturating_grid_dim).
+  std::size_t needed_bytes = 0;
+  // Always allocate the int heap at `num_threads * stride_int_bytes` worst case. Int allocas are
+  // autodiff-emitted at the offload root unconditionally (loop-counter recovery, branch flags), so every
+  // dispatched thread reaches them and the eager `linear_tid * stride_int + int_offset` layout demands a row per
+  // thread.
+  if (stride_int_bytes > 0) {
+    const std::size_t int_bytes = stride_int_bytes * num_threads;
+    if (std::getenv("QD_DEBUG_ADSTACK")) {
+      std::fprintf(stderr,
+                   "[adstack_heap] arch=llvm kind=I src=worst_case_num_threads num_threads=%zu stride=%zu "
+                   "required_bytes=%zu (%.2f MB)\n",
+                   num_threads, stride_int_bytes, int_bytes, double(int_bytes) / (1024.0 * 1024.0));
+      std::fflush(stderr);
+    }
+    ensure_adstack_heap_int(int_bytes);
+  }
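> The generation-snapshot validity check behind the cache entry recorded above is worth seeing in isolation: record the generation counter of every input at fill time; a hit is valid only while every recorded generation still matches the live one. All names below are illustrative, not the `AdStackCache` API:

```cpp
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

struct GenCache {
  std::unordered_map<int, uint64_t> live_gen;  // bumped on every tracked write
  struct Entry {
    std::vector<std::pair<int, uint64_t>> deps;  // (input id, gen at fill time)
    // ... cached payload would live here ...
  };
  std::unordered_map<const void *, Entry> entries;

  bool try_hit(const void *key) const {
    auto it = entries.find(key);
    if (it == entries.end()) return false;
    for (auto [id, gen] : it->second.deps) {
      auto g = live_gen.find(id);
      if (g == live_gen.end() || g->second != gen) return false;  // input mutated
    }
    return true;  // every dependency unchanged since the sizer last ran
  }
};

int main() {
  GenCache c;
  int key = 0;
  c.live_gen[7] = 3;
  c.entries[&key] = {{{7, 3}}};
  bool hit_before = c.try_hit(&key);  // true
  c.live_gen[7] = 4;                  // a launcher bumps the generation
  bool hit_after = c.try_hit(&key);   // false: forces a fresh sizer dispatch
  return (hit_before && !hit_after) ? 0 : 1;
}
```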
+  // Float heap: deferred to `ensure_per_task_float_heap_post_reducer` for tasks with a captured `bound_expr` (the
+  // reducer-published count drives the sizing); for non-bound_expr tasks, size at `num_threads *
+  // stride_float_bytes` worst case here. The eager float path uses `linear_tid` as the row index, so every
+  // dispatched thread needs backing storage; only the bound_expr path can shrink to `count * stride_float_bytes`.
+  if (stride_float_bytes > 0 && !ad_stack.bound_expr.has_value()) {
+    const std::size_t float_bytes = stride_float_bytes * num_threads;
+    if (std::getenv("QD_DEBUG_ADSTACK")) {
+      std::fprintf(stderr,
+                   "[adstack_heap] arch=llvm kind=F src=worst_case_num_threads_no_bound_expr num_threads=%zu "
+                   "stride=%zu required_bytes=%zu (%.2f MB)\n",
+                   num_threads, stride_float_bytes, float_bytes, double(float_bytes) / (1024.0 * 1024.0));
+      std::fflush(stderr);
+    }
+    ensure_adstack_heap_float(float_bytes);
+  }
+  last_published_stride_float_bytes_ = stride_float_bytes;
+  return needed_bytes;
+}
+
+}  // namespace quadrants::lang
diff --git a/quadrants/runtime/llvm/adstack_lazy_claim/metadata_publish.h b/quadrants/runtime/llvm/adstack_lazy_claim/metadata_publish.h
new file mode 100644
index 0000000000..57bf8fc913
--- /dev/null
+++ b/quadrants/runtime/llvm/adstack_lazy_claim/metadata_publish.h
@@ -0,0 +1,20 @@
+// Stage B of the LLVM sparse-adstack-heap lazy-claim pipeline: per-launch metadata publish. Encodes the
+// per-task SizeExpr trees into device bytecode (or runs the host evaluator when every leaf is
+// host-resolvable), launches the sizer runtime function on CUDA / AMDGPU when needed, and packs and publishes
+// `[stride_combined, stride_float, stride_int, offsets[n_stacks], max_sizes[n_stacks]]` through pinned-host
+// async H2Ds on GPU or a synchronous `memcpy` on CPU. The per-task metadata cache fast path republishes a
+// hit's payload through the same pinned-async pipeline. Closes by invoking `ensure_adstack_heap_int` /
+// `ensure_adstack_heap_float` (Stage C) for the stripes the codegen will read at row-claim time.
+//
+// All entry points are member methods of `LlvmRuntimeExecutor` and stay declared in
+// `quadrants/runtime/llvm/llvm_runtime_executor.h`. This header is the place to add file-local helpers that
+// future cross-stage callers might need; the stage's lambdas (`align_up_8`, `grow_to`, `copy_h2d`,
+// `copy_d2h`, `publish_metadata_pinned_async`) all stay scoped inside `publish_adstack_metadata`.
+
+#pragma once
+
+namespace quadrants::lang {
+
+// Reserved for future cross-stage helper declarations.
+
+}  // namespace quadrants::lang
diff --git a/quadrants/runtime/llvm/llvm_adstack_lazy_claim.cpp b/quadrants/runtime/llvm/llvm_adstack_lazy_claim.cpp
deleted file mode 100644
index 914e64d5ba..0000000000
--- a/quadrants/runtime/llvm/llvm_adstack_lazy_claim.cpp
+++ /dev/null
@@ -1,1232 +0,0 @@
-// Static-IR-bound sparse-adstack-heap reducer dispatch + lazy-claim buffer plumbing + split-heap grow-on-demand
-// for LLVM backends (CPU / CUDA / AMDGPU). Extracted out of `llvm_runtime_executor.cpp` for the same reason the
-// SPIR-V counterpart `quadrants/runtime/gfx/adstack_bound_reducer_launch.cpp` is - to keep
-// `LlvmRuntimeExecutor`'s body focused on runtime-init / SNode / kernel-launch plumbing that is not tied to the
-// bound-reducer feature.
-//
-// Methods landing here all share the same triple of responsibilities, gated on the captured `bound_expr` field of
-// `AdStackSizingInfo`:
-//  1. Allocate / clear the per-task lazy-claim arrays (`adstack_row_counters[num_tasks]` for the LCA-block
-//     atomic-rmw target, `adstack_bound_row_capacities[num_tasks]` for the codegen-emitted bounds clamp).
-//  2. Evaluate the captured `StaticAdStackBoundExpr` over `[0, length)` and publish the gate-passing count into
-//     the per-task capacity slot.
-//     CPU walks the gating field on the host directly; CUDA / AMDGPU dispatch a
-//     single-thread device-side reducer (`runtime_eval_static_bound_count` in `runtime_module/runtime.cpp`).
-//  3. Size the float / int adstack heaps from the published count via `ensure_adstack_heap_float` /
-//     `ensure_adstack_heap_int` so each heap holds exactly `count * stride` bytes per dispatch instead of the
-//     dispatched-threads worst case. The split-heap field-of-LLVMRuntime addresses are cached on first grow by
-//     either `_float` or `_int` (the `runtime_get_adstack_split_heap_field_ptrs` getter returns all four in
-//     fixed slot order).
-//
-// All methods (and the two anonymous-namespace helpers) are conditional on at least one task in the kernel having
-// a captured `bound_expr`; on kernels without one, or on the `cfg_optimization=False` cache-miss path that did
-// not capture a gate, the methods early-return UINT32_MAX (capacity stays at the inert sentinel
-// `publish_adstack_lazy_claim_buffers` wrote) and the dispatched-threads worst-case heap sizing remains in force.
-//
-// Caller responsibility (in `kernel_launcher.cpp` for each arch): invoke `publish_adstack_lazy_claim_buffers`
-// once per kernel-launch before the first task dispatches, then per task call either
-// `publish_per_task_bound_count_cpu` or `publish_per_task_bound_count_device` (arch-dispatched), then
-// `ensure_per_task_float_heap_post_reducer`. Tasks without a captured `bound_expr` have those calls early-return.
-
-#include "quadrants/runtime/llvm/llvm_runtime_executor.h"
-#include "quadrants/program/adstack_size_expr_eval.h"
-#include "quadrants/program/program.h"
-
-#include <atomic>
-#include <cstdint>
-#include <cstdio>
-#include <cstdlib>
-#include <cstring>
-#include <limits>
-#include <vector>
-
-#include "quadrants/ir/static_adstack_bound_reducer_device.h"
-#include "quadrants/ir/stmt_op_types.h"
-#include "quadrants/ir/type_factory.h"
-#include "quadrants/program/launch_context_builder.h"
-#include "quadrants/program/program_impl.h"
-#include "quadrants/rhi/llvm/llvm_device.h"
-
-#include "quadrants/platform/cuda/detect_cuda.h"
-#include "quadrants/rhi/cuda/cuda_driver.h"
-#if defined(QD_WITH_CUDA)
-#include "quadrants/rhi/cuda/cuda_context.h"
-#endif
-
-#include "quadrants/platform/amdgpu/detect_amdgpu.h"
-#include "quadrants/rhi/amdgpu/amdgpu_driver.h"
-#if defined(QD_WITH_AMDGPU)
-#include "quadrants/rhi/amdgpu/amdgpu_context.h"
-#endif
-
-namespace quadrants::lang {
-
-namespace {
-
-// Decode the captured `BinaryOpType` (stored as an int in `cmp_op`) and evaluate it against typed operands.
-// Mirrors the SPIR-V reducer's `OpSwitch` over the same encoding.
-template <typename T>
-inline bool eval_cmp(int cmp_op, T lhs, T rhs) {
-  switch (static_cast<BinaryOpType>(cmp_op)) {
-    case BinaryOpType::cmp_lt:
-      return lhs < rhs;
-    case BinaryOpType::cmp_le:
-      return lhs <= rhs;
-    case BinaryOpType::cmp_gt:
-      return lhs > rhs;
-    case BinaryOpType::cmp_ge:
-      return lhs >= rhs;
-    case BinaryOpType::cmp_eq:
-      return lhs == rhs;
-    case BinaryOpType::cmp_ne:
-      return lhs != rhs;
-    default:
-      return false;
-  }
-}
-
-// Encode the captured `BinaryOpType` into the 0-5 numeric range the LLVM device reducer's switch consumes.
-// Mirrors the SPIR-V reducer's `encode_cmp_op` mapping at `quadrants/runtime/gfx/adstack_bound_reducer_launch.cpp`.
-uint32_t encode_cmp_op_for_llvm_reducer(int captured_cmp_op) {
-  switch (static_cast<BinaryOpType>(captured_cmp_op)) {
-    case BinaryOpType::cmp_lt:
-      return kLlvmReducerCmpLt;
-    case BinaryOpType::cmp_le:
-      return kLlvmReducerCmpLe;
-    case BinaryOpType::cmp_gt:
-      return kLlvmReducerCmpGt;
-    case BinaryOpType::cmp_ge:
-      return kLlvmReducerCmpGe;
-    case BinaryOpType::cmp_eq:
-      return kLlvmReducerCmpEq;
-    case BinaryOpType::cmp_ne:
-      return kLlvmReducerCmpNe;
-    default:
-      return std::numeric_limits<uint32_t>::max();
-  }
-}
-
-}  // namespace
-
-uint32_t LlvmRuntimeExecutor::publish_per_task_bound_count_cpu(std::size_t task_index,
-                                                               const AdStackSizingInfo &ad_stack,
-                                                               std::size_t length,
-                                                               LaunchContextBuilder *ctx) {
-  // Default to UINT32_MAX (no clamp); only override on a successful host evaluation. The codegen-emitted bounds
-  // clamp at the float LCA-block claim site stays inert when the slot holds UINT32_MAX, so this fall-through is a
-  // no-op that preserves the existing behaviour.
-  if (config_.arch != Arch::x64 && config_.arch != Arch::arm64) {
-    return std::numeric_limits<uint32_t>::max();
-  }
-  if (!ad_stack.bound_expr.has_value()) {
-    return std::numeric_limits<uint32_t>::max();
-  }
-  const auto &be = ad_stack.bound_expr.value();
-
-  // Resolve the per-iteration field address. Two source kinds (mirrors the device-side reducer in
-  // `runtime_eval_static_bound_count`):
-  //  * NdArray: walk `arg_buffer + data_ptr_byte_off` to fetch the ndarray's data pointer; the gating field
-  //    is then `data_ptr[i]` for `i in [0, length)`. On CPU `arg_buffer` lives in host memory, so the deref is
-  //    direct.
-  //  * SNode: walk `runtime->roots[snode_root_id] + snode_byte_base_offset + i * snode_byte_cell_stride`
-  //    for `i in [0, length)`. The byte offset / cell stride were resolved by the codegen-time SNode descriptor
-  //    resolver (via `compile_snode_structs`); `runtime->roots` is host-resident on CPU and reachable through the
-  //    `LLVMRuntime_get_roots` STRUCT_FIELD_ARRAY getter.
-  // Without the SNode arm, kernels with a captured SNode-backed bound_expr leave the capacity slot at UINT32_MAX
-  // (the `publish_adstack_lazy_claim_buffers` default), `ensure_per_task_float_heap_post_reducer` sizes the float
-  // heap at the worst-case num_threads count, and the codegen-emitted clamp goes inert - exactly the regression a
-  // `for i in selector: if selector[i] > eps:` SNode-gated reverse kernel hits when the float adstack heap can
-  // only hold `num_cpu_threads` rows but the LCA-block atomic-rmw fires once per gated iteration.
-  using FSK = StaticAdStackBoundExpr::FieldSourceKind;
-  if (be.field_source_kind != FSK::NdArray && be.field_source_kind != FSK::SNode) {
-    return std::numeric_limits<uint32_t>::max();
-  }
-
-  const char *field_base = nullptr;
-  std::size_t field_stride_bytes = 0;
-  if (be.field_source_kind == FSK::NdArray) {
-    if (ctx == nullptr || ctx->args_type == nullptr || ctx->get_context().arg_buffer == nullptr) {
-      return std::numeric_limits<uint32_t>::max();
-    }
-    std::vector<int> indices = be.ndarray_arg_id;
-    indices.push_back(TypeFactory::DATA_PTR_POS_IN_NDARRAY);
-    std::size_t data_ptr_byte_off = ctx->args_type->get_element_offset(indices);
-    const char *arg_buffer = static_cast<const char *>(ctx->get_context().arg_buffer);
-    void *data_ptr = *reinterpret_cast<void *const *>(arg_buffer + data_ptr_byte_off);
-    if (data_ptr == nullptr) {
-      return std::numeric_limits<uint32_t>::max();
-    }
-    field_base = static_cast<const char *>(data_ptr);
-    field_stride_bytes = be.field_dtype_is_double ? sizeof(double) : sizeof(int32_t);  // f32 / i32 = 4 B, f64 = 8 B.
-  } else {
-    // SNode-backed source: query the host-resident `runtime->roots[snode_root_id]` pointer through the
-    // STRUCT_FIELD_ARRAY getter; on CPU this is an in-process call (no DtoH stage) and returns the dense root
-    // buffer base address directly.
-    if (be.snode_root_id < 0 || llvm_runtime_ == nullptr || result_buffer_cache_ == nullptr) {
-      return std::numeric_limits<uint32_t>::max();
-    }
-    // `RUNTIME_STRUCT_FIELD_ARRAY(LLVMRuntime, roots)` defines `runtime_LLVMRuntime_get_roots(LLVMRuntime
-    // *runtime, LLVMRuntime *s, int i)` (the macro takes a struct-of-interest argument distinct from the runtime
-    // context, but for fields of `LLVMRuntime` itself the two pointers are the same). `runtime_query`
-    // auto-prepends `llvm_runtime_` as the first arg, so we pass `(llvm_runtime_, root_id)` to make the call
-    // resolve to the 3-arg signature `(llvm_runtime_, llvm_runtime_, root_id)`. Mirrors the `node_allocators`
-    // call site a few hundred lines above.
-    void *root_ptr =
-        runtime_query<void *>("LLVMRuntime_get_roots", result_buffer_cache_, llvm_runtime_, be.snode_root_id);
-    if (root_ptr == nullptr) {
-      return std::numeric_limits<uint32_t>::max();
-    }
-    field_base = static_cast<const char *>(root_ptr) + be.snode_byte_base_offset;
-    field_stride_bytes = static_cast<std::size_t>(be.snode_byte_cell_stride);
-  }
-
-  // Walk `[0, length)` evaluating the captured predicate on each thread's `field[i]`. The polarity bit selects
-  // enter-on-true vs enter-on-false at the LCA's IfStmt; the count we publish is always the number of threads
-  // that REACH the LCA, regardless of the gate orientation. f64 gates dispatch through the same float-source arm
-  // but read the source as `double *` and compare against `literal_f64`, so the f64 precision the user declared
-  // is preserved end-to-end (narrowing the literal to f32 here would risk false-positive / false-negative counts
-  // on gates whose threshold sits within the f32 representable gap).
-  uint32_t count = 0;
-  if (be.field_dtype_is_float) {
-    if (be.field_dtype_is_double) {
-      for (std::size_t i = 0; i < length; ++i) {
-        const double v = *reinterpret_cast<const double *>(field_base + i * field_stride_bytes);
-        const bool match = eval_cmp(be.cmp_op, v, be.literal_f64);
-        if (be.polarity ? match : !match) {
-          ++count;
-        }
-      }
-    } else {
-      for (std::size_t i = 0; i < length; ++i) {
-        const float v = *reinterpret_cast<const float *>(field_base + i * field_stride_bytes);
-        const bool match = eval_cmp(be.cmp_op, v, be.literal_f32);
-        if (be.polarity ? match : !match) {
-          ++count;
-        }
-      }
-    }
-  } else {
-    for (std::size_t i = 0; i < length; ++i) {
-      const int32_t v = *reinterpret_cast<const int32_t *>(field_base + i * field_stride_bytes);
-      const bool match = eval_cmp(be.cmp_op, v, be.literal_i32);
-      if (be.polarity ? match : !match) {
-        ++count;
-      }
-    }
-  }
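> The walk above is a plain polarity-gated count. A self-contained f32-only version (the `Cmp` enum and `gate_count` name are illustrative; the real code switches on the captured `BinaryOpType`):

```cpp
#include <cstdint>
#include <vector>

enum class Cmp { lt, le, gt, ge, eq, ne };
inline bool eval_cmp(Cmp op, float lhs, float rhs) {
  switch (op) {
    case Cmp::lt: return lhs < rhs;   case Cmp::le: return lhs <= rhs;
    case Cmp::gt: return lhs > rhs;   case Cmp::ge: return lhs >= rhs;
    case Cmp::eq: return lhs == rhs;  default:      return lhs != rhs;
  }
}

uint32_t gate_count(const std::vector<float> &field, Cmp op, float threshold, bool polarity) {
  uint32_t count = 0;
  for (float v : field) {
    const bool match = eval_cmp(op, v, threshold);
    if (polarity ? match : !match) ++count;  // count iterations that REACH the LCA
  }
  return count;
}

int main() {
  // Enter-on-true gate `v > 0.5`: two of four iterations claim a row.
  return gate_count({0.1f, 0.9f, 0.7f, 0.2f}, Cmp::gt, 0.5f, true) == 2 ? 0 : 1;
}
```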
-  // Publish the count into `runtime->adstack_bound_row_capacities[task_index]` so the codegen-emitted bounds
-  // clamp at the float LCA-block claim site reads it back as the per-task capacity. The slot was reset to
-  // UINT32_MAX by `publish_adstack_lazy_claim_buffers`; this overwrite tightens it to the real count.
-  if (runtime_adstack_bound_row_capacities_field_ptr_ == nullptr || adstack_bound_row_capacities_alloc_ == nullptr) {
-    return count;
-  }
-  void *bound_capacities_dev_ptr = get_device_alloc_info_ptr(*adstack_bound_row_capacities_alloc_);
-  // CPU only: write directly into the host-resident array.
-  uint32_t *slots = static_cast<uint32_t *>(bound_capacities_dev_ptr);
-  slots[task_index] = count;
-  return count;
-}
-
-void LlvmRuntimeExecutor::publish_per_task_bound_count_device(std::size_t task_index,
-                                                              const AdStackSizingInfo &ad_stack,
-                                                              std::size_t length,
-                                                              LaunchContextBuilder *ctx,
-                                                              void *device_runtime_context_ptr) {
-  // Only fires for CUDA / AMDGPU; CPU goes through `publish_per_task_bound_count_cpu`. Bail when the task did not
-  // capture a bound_expr (no clamp needed - the slot stays at the UINT32_MAX default that
-  // `publish_adstack_lazy_claim_buffers` wrote). Both ndarray and SNode source kinds are dispatched through the
-  // same params blob; the device-side reducer selects between them via `field_source_is_snode`.
-  if (config_.arch != Arch::cuda && config_.arch != Arch::amdgpu) {
-    return;
-  }
-  if (!ad_stack.bound_expr.has_value()) {
-    return;
-  }
-  const auto &be = ad_stack.bound_expr.value();
-  const bool is_snode_source = be.field_source_kind == StaticAdStackBoundExpr::FieldSourceKind::SNode;
-  if (ctx == nullptr || ctx->args_type == nullptr) {
-    return;
-  }
-  const uint32_t cmp_op_encoded = encode_cmp_op_for_llvm_reducer(be.cmp_op);
-  if (cmp_op_encoded == std::numeric_limits<uint32_t>::max()) {
-    return;  // unrecognised comparison op (the IR pattern matcher should have rejected it earlier)
-  }
-
-  // Fill the device-side params struct on the host. Threshold bits live as the same u32 the runtime function
-  // bitcasts back; we copy whichever underlying integer or float value the analysis captured. The two source
-  // shapes (ndarray + SNode) share the comparison fields and differ only in which trailing fields the reducer
-  // reads (`arg_word_offset` for ndarray, `snode_root_id` + `snode_byte_*` for SNode); host-side we populate the
-  // matching pair and zero out the other.
-  LlvmAdStackBoundReducerDeviceParams params{};
-  params.task_index = static_cast<uint32_t>(task_index);
-  params.length = static_cast<uint32_t>(is_snode_source ? be.snode_iter_count : length);
-  params.cmp_op = cmp_op_encoded;
-  params.field_dtype_is_float = be.field_dtype_is_float ? 1u : 0u;
-  params.field_dtype_is_double = be.field_dtype_is_double ? 1u : 0u;
-  params.polarity = be.polarity ? 1u : 0u;
-  if (be.field_dtype_is_double) {
-    // Pack the f64 threshold's 64-bit pattern into the (lo, hi) u32 pair the reducer reassembles.
-    uint64_t bits64 = 0;
-    std::memcpy(&bits64, &be.literal_f64, sizeof(uint64_t));
-    params.threshold_bits = static_cast<uint32_t>(bits64 & 0xFFFFFFFFu);
-    params.threshold_bits_high = static_cast<uint32_t>(bits64 >> 32);
-  } else if (be.field_dtype_is_float) {
-    std::memcpy(&params.threshold_bits, &be.literal_f32, sizeof(uint32_t));
-  } else {
-    params.threshold_bits = static_cast<uint32_t>(be.literal_i32);
-  }
-  params.field_source_is_snode = is_snode_source ? 1u : 0u;
-  if (is_snode_source) {
-    params.arg_word_offset = 0;
-    params.snode_root_id = static_cast<uint32_t>(be.snode_root_id);
-    params.snode_byte_base_offset = be.snode_byte_base_offset;
-    params.snode_byte_cell_stride = be.snode_byte_cell_stride;
-  } else {
-    // Resolve the ndarray data pointer's word offset within the kernel arg buffer. Same path the SPIR-V reducer
-    // and the CPU host-eval use; bytes -> words for the reducer's `arg_buffer_u32[arg_word_offset]` indexing.
-    std::vector<int> indices = be.ndarray_arg_id;
-    indices.push_back(TypeFactory::DATA_PTR_POS_IN_NDARRAY);
-    std::size_t data_ptr_byte_off = ctx->args_type->get_element_offset(indices);
-    if (data_ptr_byte_off % sizeof(uint32_t) != 0) {
-      return;  // misaligned offset; the reducer's u32-word indexing would lose bits.
-    }
-    params.arg_word_offset = static_cast<uint32_t>(data_ptr_byte_off / sizeof(uint32_t));
-    params.snode_root_id = 0;
-    params.snode_byte_base_offset = 0;
-    params.snode_byte_cell_stride = 0;
-  }
-
-  // Lazy-allocate the device-side params scratch buffer the first time a bound_expr task fires; reuse it for
-  // subsequent tasks across kernels. Sized for one struct (the reducer is single-task per call); a future
-  // optimisation could pack multiple tasks' params into one buffer and dispatch them in a single launch.
-  const std::size_t needed_bytes = sizeof(LlvmAdStackBoundReducerDeviceParams);
-  if (needed_bytes > adstack_bound_reducer_params_capacity_) {
-    Device::AllocParams alloc_params{};
-    alloc_params.size = std::max(needed_bytes, 2 * adstack_bound_reducer_params_capacity_);
-    alloc_params.host_read = false;
-    alloc_params.host_write = true;
-    alloc_params.export_sharing = false;
-    alloc_params.usage = AllocUsage::Storage;
-    DeviceAllocation new_alloc;
-    RhiResult res = llvm_device()->allocate_memory(alloc_params, &new_alloc);
-    QD_ERROR_IF(res != RhiResult::success,
-                "Failed to allocate {} bytes for adstack bound reducer params buffer (err: {})", alloc_params.size,
-                int(res));
-    adstack_bound_reducer_params_alloc_ = std::make_unique<DeviceAllocationGuard>(std::move(new_alloc));
-    adstack_bound_reducer_params_capacity_ = alloc_params.size;
-  }
-  void *params_dev_ptr = get_device_alloc_info_ptr(*adstack_bound_reducer_params_alloc_);
-
-  // h2d the params struct into the device buffer.
-  if (config_.arch == Arch::cuda) {
-#if defined(QD_WITH_CUDA)
-    CUDADriver::get_instance().memcpy_host_to_device(params_dev_ptr, &params, needed_bytes);
-#else
-    QD_NOT_IMPLEMENTED;
-#endif
-  } else if (config_.arch == Arch::amdgpu) {
-#if defined(QD_WITH_AMDGPU)
-    AMDGPUDriver::get_instance().memcpy_host_to_device(params_dev_ptr, &params, needed_bytes);
-#else
-    QD_NOT_IMPLEMENTED;
-#endif
-  }
-
-  // Dispatch the runtime reducer function: a single-threaded device-side walk that reads `ctx->arg_buffer` (the
-  // device mirror the launcher staged) and writes the count into
-  // `runtime->adstack_bound_row_capacities[task_index]`. Pass the device-side `RuntimeContext` pointer the same
-  // way the size-expr sizer does so the function can deref `ctx->arg_buffer` on-device.
-  auto *const runtime_jit = get_runtime_jit_module();
-  void *runtime_context_ptr_for_reducer =
-      device_runtime_context_ptr != nullptr ? device_runtime_context_ptr : static_cast<void *>(&ctx->get_context());
-  runtime_jit->call("runtime_eval_static_bound_count", llvm_runtime_, runtime_context_ptr_for_reducer,
-                    params_dev_ptr);
-}
-
-void LlvmRuntimeExecutor::ensure_adstack_heap_int(std::size_t needed_bytes) {
-  if (needed_bytes == 0 || needed_bytes <= adstack_heap_size_int_) {
-    return;
-  }
-  std::size_t new_size = std::max(needed_bytes, std::size_t(2) * adstack_heap_size_int_);
-
-  Device::AllocParams params{};
-  params.size = new_size;
-  params.host_read = false;
-  params.host_write = false;
-  params.export_sharing = false;
-  params.usage = AllocUsage::Storage;
-  DeviceAllocation new_alloc;
-  RhiResult res = llvm_device()->allocate_memory(params, &new_alloc);
-  QD_ERROR_IF(res != RhiResult::success,
-              "Failed to allocate {} bytes for the adstack int heap (err: {}). Consider lowering "
-              "`ad_stack_size` or the per-kernel reverse-mode adstack count.",
-              new_size, int(res));
-  void *new_ptr = get_device_alloc_info_ptr(new_alloc);
-  auto new_guard = std::make_unique<DeviceAllocationGuard>(std::move(new_alloc));
-
-  // The split-heap field-of-LLVMRuntime addresses are cached together by `ensure_adstack_heap_float` on its first
-  // grow (the same `runtime_get_adstack_split_heap_field_ptrs` getter returns all four addresses - float-buffer,
-  // float-size, int-buffer, int-size - in fixed slot order). On a fresh executor where this is the very first
-  // split-heap call, resolve the addresses here so we can publish independently of the float heap path.
-  if (runtime_adstack_heap_buffer_int_field_ptr_ == nullptr) {
-    auto *const runtime_jit = get_runtime_jit_module();
-    runtime_jit->call("runtime_get_adstack_split_heap_field_ptrs", llvm_runtime_);
-    runtime_adstack_heap_buffer_float_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>(
-        fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_));
-    runtime_adstack_heap_size_float_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>(
-        fetch_result_uint64(quadrants_result_buffer_ret_value_id + 1, result_buffer_cache_));
-    runtime_adstack_heap_buffer_int_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>(
-        fetch_result_uint64(quadrants_result_buffer_ret_value_id + 2, result_buffer_cache_));
-    runtime_adstack_heap_size_int_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>(
-        fetch_result_uint64(quadrants_result_buffer_ret_value_id + 3, result_buffer_cache_));
-  }
-  uint64 size_u64 = static_cast<uint64>(new_size);
-  if (config_.arch == Arch::cuda) {
-#if defined(QD_WITH_CUDA)
-    CUDADriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_buffer_int_field_ptr_, &new_ptr,
-                                                     sizeof(void *));
-    CUDADriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_size_int_field_ptr_, &size_u64,
-                                                     sizeof(uint64));
-#else
-    QD_NOT_IMPLEMENTED;
-#endif
-  } else if (config_.arch == Arch::amdgpu) {
-#if defined(QD_WITH_AMDGPU)
-    AMDGPUDriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_buffer_int_field_ptr_, &new_ptr,
-                                                       sizeof(void *));
-    AMDGPUDriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_size_int_field_ptr_, &size_u64,
-                                                       sizeof(uint64));
-#else
-    QD_NOT_IMPLEMENTED;
-#endif
-  } else {
-    *reinterpret_cast<void **>(runtime_adstack_heap_buffer_int_field_ptr_) = new_ptr;
-    *reinterpret_cast<uint64 *>(runtime_adstack_heap_size_int_field_ptr_) = size_u64;
-  }
-
-  adstack_heap_alloc_int_ = std::move(new_guard);
-  adstack_heap_size_int_ = new_size;
-}
-
-void LlvmRuntimeExecutor::ensure_per_task_float_heap_post_reducer(std::size_t task_index,
-                                                                  const AdStackSizingInfo &ad_stack,
-                                                                  std::size_t num_threads,
-                                                                  LaunchContextBuilder *ctx) {
-  // Skip when the task has no float heap need (no f32 allocas, or the analysis didn't capture a gate, so we
-  // wouldn't have routed it through the lazy float path on the codegen side).
-  if (!ad_stack.bound_expr.has_value() || ad_stack.per_thread_stride_float == 0) {
-    return;
-  }
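> The row-count decision this function makes from the published capacity slot has three cases. A sketch mirroring the logic that follows (the `effective_rows` helper is illustrative; trip-count clipping is omitted):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

std::size_t effective_rows(uint32_t published, std::size_t num_threads) {
  if (published == std::numeric_limits<uint32_t>::max())
    return num_threads;                        // inert sentinel: no reducer ran, worst case
  return std::max<std::size_t>(published, 1);  // zero-count launches keep one row
}

int main() {
  assert(effective_rows(std::numeric_limits<uint32_t>::max(), 4096) == 4096);
  assert(effective_rows(0, 4096) == 1);   // floor keeps the heap pointer non-null
  assert(effective_rows(37, 4096) == 37); // reducer count drives the allocation
}
```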
-  // Read the per-task count the reducer published. On CPU the capacity buffer is host-resident; on CUDA / AMDGPU
-  // it's device memory and the read is a small (4-byte) DtoH per task. Cost is dominated by the actual main
-  // kernel.
-  uint32_t count = std::numeric_limits<uint32_t>::max();
-  if (adstack_bound_row_capacities_alloc_) {
-    void *capacities_dev_ptr = get_device_alloc_info_ptr(*adstack_bound_row_capacities_alloc_);
-    char *slot_ptr = static_cast<char *>(capacities_dev_ptr) + task_index * sizeof(uint32_t);
-    if (config_.arch == Arch::cuda) {
-#if defined(QD_WITH_CUDA)
-      CUDADriver::get_instance().memcpy_device_to_host(&count, slot_ptr, sizeof(uint32_t));
-#else
-      QD_NOT_IMPLEMENTED;
-#endif
-    } else if (config_.arch == Arch::amdgpu) {
-#if defined(QD_WITH_AMDGPU)
-      AMDGPUDriver::get_instance().memcpy_device_to_host(&count, slot_ptr, sizeof(uint32_t));
-#else
-      QD_NOT_IMPLEMENTED;
-#endif
-    } else {
-      count = *reinterpret_cast<const uint32_t *>(slot_ptr);
-    }
-  }
-
-  // Floor at 1 row when the captured count is zero (no thread passed the gate this launch). The codegen-emitted
-  // bounds clamp keeps `claimed_row` in [0, count-1], so threads that miss the gate never reach the LCA-block
-  // claim - the heap row stays unused. A 1-row allocation is cheap and keeps the heap pointer non-null. Clip by
-  // the captured compile-time loop trip count when known: each iteration claims at most one row at the LCA-block
-  // (one `atomic_add` per gating iteration), so the heap needs at most `loop_iter_static` rows regardless of how
-  // many cells of an oversized gating SNode the reducer counted. The analyzer leaves `loop_iter_static == 0` for
-  // runtime-bounded loops and for CPU LLVM tasks whose `[begin_value, end_value)` is a post-chunking subrange
-  // (the unclipped reducer count is the right upper bound there).
-  std::size_t effective_rows = (count == std::numeric_limits<uint32_t>::max())
-                                   ? num_threads
-                                   : std::max<std::size_t>(count, 1);
-  if (count != std::numeric_limits<uint32_t>::max() && ad_stack.bound_expr.has_value()) {
-    // Shared with the SPIR-V launcher: see `clip_effective_rows_by_loop_trip_count` in
-    // `program/adstack_size_expr_eval.cpp`. LLVM dispatches one thread per loop iteration without the
-    // SPIR-V dispatch-cap-driven serialisation, so pass `numeric_limits<std::size_t>::max()` to disable the
-    // dispatched-threads ceiling - any positive trip-count value is a sound upper bound on row claims
-    // here. `numeric_limits<std::size_t>::max()` is the ceiling sentinel `clip_effective_rows_by_loop_trip_count`
-    // documents.
-    Program *prog = (program_impl_ != nullptr) ? program_impl_->program : nullptr;
-    clip_effective_rows_by_loop_trip_count(effective_rows, *ad_stack.bound_expr,
-                                           std::numeric_limits<std::size_t>::max(), prog, ctx);
-  }
-  // The per-thread float stride (in bytes) was just published into `runtime->adstack_per_thread_stride_float` by
-  // the matching `publish_adstack_metadata` call earlier in this task's per-task block. We stash the value
-  // host-side so we can read it directly here instead of paying a sync DtoH on every bound_expr task. The
-  // launcher pairs publish + reducer + post-reducer per task with no intervening publish for another task, so the
-  // stash is accurate at this call site. `AdStackSizingInfo::per_thread_stride_float` from the analysis pre-pass
-  // is in entry-count units (`2 * max_size`), not bytes, and would massively undersize the heap.
-  uint64_t stride_float_bytes_u64 = static_cast<uint64_t>(last_published_stride_float_bytes_);
-  const std::size_t needed_bytes = effective_rows * static_cast<std::size_t>(stride_float_bytes_u64);
-  // `QD_DEBUG_ADSTACK=1` opt-in diagnostic. Kept in permanently so memory regressions can be debugged without
-  // re-instrumenting.
-  if (std::getenv("QD_DEBUG_ADSTACK")) {
-    const char *src = (count == std::numeric_limits<uint32_t>::max())
"worst_case_num_threads" - : (count == 0 ? "reducer_zero_floored" : "reducer_count"); - std::fprintf(stderr, - "[adstack_heap] arch=llvm task_idx=%zu kind=F src=%s effective_rows=%zu stride=%llu " - "required_bytes=%zu (%.2f MB)\n", - task_index, src, effective_rows, static_cast(stride_float_bytes_u64), needed_bytes, - double(needed_bytes) / (1024.0 * 1024.0)); - std::fflush(stderr); - } - ensure_adstack_heap_float(needed_bytes); -} - -void LlvmRuntimeExecutor::publish_adstack_lazy_claim_buffers(std::size_t num_tasks) { - if (num_tasks == 0) { - return; - } - // Cache the field-of-LLVMRuntime addresses for the row counter / bound row capacity array pointers. Resolved once per - // program lifetime; subsequent grows write the new array pointers directly to the cached addresses. - if (runtime_adstack_row_counters_field_ptr_ == nullptr) { - auto *const runtime_jit = get_runtime_jit_module(); - runtime_jit->call("runtime_get_adstack_lazy_claim_field_ptrs", llvm_runtime_); - runtime_adstack_row_counters_field_ptr_ = quadrants_union_cast_with_different_sizes( - fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_)); - runtime_adstack_bound_row_capacities_field_ptr_ = quadrants_union_cast_with_different_sizes( - fetch_result_uint64(quadrants_result_buffer_ret_value_id + 1, result_buffer_cache_)); - } - - auto grow_to = [&](DeviceAllocationUnique &alloc, std::size_t capacity_u32) { - Device::AllocParams params{}; - params.size = capacity_u32 * sizeof(uint32_t); - params.host_read = false; - params.host_write = false; - params.export_sharing = false; - params.usage = AllocUsage::Storage; - DeviceAllocation new_alloc; - RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); - QD_ERROR_IF(res != RhiResult::success, "Failed to allocate {} bytes for adstack lazy-claim array (err: {})", - params.size, int(res)); - alloc = std::make_unique(std::move(new_alloc)); - }; - - bool grew = false; - if (num_tasks > adstack_lazy_claim_capacity_) { - std::size_t new_cap = std::max(num_tasks, 2 * adstack_lazy_claim_capacity_); - grow_to(adstack_row_counters_alloc_, new_cap); - grow_to(adstack_bound_row_capacities_alloc_, new_cap); - adstack_lazy_claim_capacity_ = new_cap; - grew = true; - } - void *row_counters_dev_ptr = get_device_alloc_info_ptr(*adstack_row_counters_alloc_); - void *bound_capacities_dev_ptr = get_device_alloc_info_ptr(*adstack_bound_row_capacities_alloc_); - - // After every grow, publish the new array pointers into the runtime so the codegen-emitted GEPs - // (`runtime->adstack_row_counters[task_codegen_id]` and `runtime->adstack_bound_row_capacities[task_codegen_id]`) - // resolve against the live allocations. Skipped between grows because the cached field address holds the same pointer - // value. 
-  auto copy_h2d = [&](void *dst, const void *src, std::size_t bytes) {
-    if (config_.arch == Arch::cuda) {
-#if defined(QD_WITH_CUDA)
-      CUDADriver::get_instance().memcpy_host_to_device(dst, const_cast<void *>(src), bytes);
-#else
-      QD_NOT_IMPLEMENTED;
-#endif
-    } else if (config_.arch == Arch::amdgpu) {
-#if defined(QD_WITH_AMDGPU)
-      AMDGPUDriver::get_instance().memcpy_host_to_device(dst, const_cast<void *>(src), bytes);
-#else
-      QD_NOT_IMPLEMENTED;
-#endif
-    } else {
-      std::memcpy(dst, src, bytes);
-    }
-  };
-  if (grew) {
-    copy_h2d(runtime_adstack_row_counters_field_ptr_, &row_counters_dev_ptr, sizeof(void *));
-    copy_h2d(runtime_adstack_bound_row_capacities_field_ptr_, &bound_capacities_dev_ptr, sizeof(void *));
-  }
-
-  // Per-launch reset: zero the counter slots (each task's LCA-block atomic-rmw add starts from 0 and accumulates
-  // its own claims) and write UINT32_MAX into the capacity slots so the codegen-emitted bounds clamp is inert
-  // unless a later reducer dispatch overrides slots with tighter counts. Buffer fill rather than per-slot store:
-  // the host pays one O(num_tasks) fill per kernel-launch, regardless of arch.
-  std::vector<uint32_t> zero_buf(num_tasks, 0u);
-  std::vector<uint32_t> uint_max_buf(num_tasks, std::numeric_limits<uint32_t>::max());
-  copy_h2d(row_counters_dev_ptr, zero_buf.data(), num_tasks * sizeof(uint32_t));
-  copy_h2d(bound_capacities_dev_ptr, uint_max_buf.data(), num_tasks * sizeof(uint32_t));
-}
-
-void LlvmRuntimeExecutor::ensure_adstack_heap_float(std::size_t needed_bytes) {
-  if (needed_bytes == 0 || needed_bytes <= adstack_heap_size_float_) {
-    return;
-  }
-  // Mirror `ensure_adstack_heap`'s amortised-doubling growth and grow-on-demand semantics. The float heap is
-  // allocated independently from the combined heap so a kernel with bound_expr tasks can shrink the combined
-  // slice to int-only while still backing float allocas at `row_id_var * stride_float + float_offset`.
-  std::size_t new_size = std::max(needed_bytes, std::size_t(2) * adstack_heap_size_float_);
-
-  Device::AllocParams params{};
-  params.size = new_size;
-  params.host_read = false;
-  params.host_write = false;
-  params.export_sharing = false;
-  params.usage = AllocUsage::Storage;
-  DeviceAllocation new_alloc;
-  RhiResult res = llvm_device()->allocate_memory(params, &new_alloc);
-  QD_ERROR_IF(res != RhiResult::success,
-              "Failed to allocate {} bytes for the adstack float heap (err: {}). Consider lowering "
-              "`ad_stack_size` or the per-kernel reverse-mode adstack count.",
-              new_size, int(res));
-  void *new_ptr = get_device_alloc_info_ptr(new_alloc);
-  auto new_guard = std::make_unique<DeviceAllocationGuard>(std::move(new_alloc));
-
-  // Resolve and cache the field-of-LLVMRuntime addresses for the split-heap fields on first grow. The
-  // `runtime_get_adstack_split_heap_field_ptrs` helper returns four addresses in fixed slot order:
-  // float-buffer-ptr, float-size, int-buffer-ptr, int-size. We only consume the float pair here; the int pair is
-  // written by `ensure_adstack_heap_int` above, which resolves the same four addresses itself when it fires
-  // first.
-  if (runtime_adstack_heap_buffer_float_field_ptr_ == nullptr) {
-    auto *const runtime_jit = get_runtime_jit_module();
-    runtime_jit->call("runtime_get_adstack_split_heap_field_ptrs", llvm_runtime_);
-    runtime_adstack_heap_buffer_float_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>(
-        fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_));
-    runtime_adstack_heap_size_float_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>(
-        fetch_result_uint64(quadrants_result_buffer_ret_value_id + 1, result_buffer_cache_));
-    runtime_adstack_heap_buffer_int_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>(
-        fetch_result_uint64(quadrants_result_buffer_ret_value_id + 2, result_buffer_cache_));
-    runtime_adstack_heap_size_int_field_ptr_ = quadrants_union_cast_with_different_sizes<void *>(
-        fetch_result_uint64(quadrants_result_buffer_ret_value_id + 3, result_buffer_cache_));
-  }
-  uint64 size_u64 = static_cast<uint64>(new_size);
-  if (config_.arch == Arch::cuda) {
-#if defined(QD_WITH_CUDA)
-    CUDADriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_buffer_float_field_ptr_, &new_ptr,
-                                                     sizeof(void *));
-    CUDADriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_size_float_field_ptr_, &size_u64,
-                                                     sizeof(uint64));
-#else
-    QD_NOT_IMPLEMENTED;
-#endif
-  } else if (config_.arch == Arch::amdgpu) {
-#if defined(QD_WITH_AMDGPU)
-    AMDGPUDriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_buffer_float_field_ptr_, &new_ptr,
-                                                       sizeof(void *));
-    AMDGPUDriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_size_float_field_ptr_, &size_u64,
-                                                       sizeof(uint64));
-#else
-    QD_NOT_IMPLEMENTED;
-#endif
-  } else {
-    *reinterpret_cast<void **>(runtime_adstack_heap_buffer_float_field_ptr_) = new_ptr;
-    *reinterpret_cast<uint64 *>(runtime_adstack_heap_size_float_field_ptr_) = size_u64;
-  }
-
-  adstack_heap_alloc_float_ = std::move(new_guard);
-  adstack_heap_size_float_ = new_size;
-}
-
-void LlvmRuntimeExecutor::check_adstack_overflow() {
-  // Called from `synchronize()` on every sync, plus other Quadrants Python entry points wired in
-  // `Program::check_adstack_overflow_and_raise`. The flag lives in pinned host memory (allocated at
-  // `materialize_runtime`); polling is a relaxed atomic exchange on the cached host pointer via a
-  // `std::atomic<int64_t>` reinterpret_cast - no DtoH, no JIT call, no sync drain. Available on all backends because
-  // the pinned-host memory is in the host process address space regardless of where the kernel that wrote it ran.
-  // The reinterpret_cast is portable because `std::atomic<int64_t>` is layout-compatible with `int64_t` on every
-  // target (verified by the static_assert below); see also Itanium ABI / MSVC ABI lock-free guarantees.
-  //
-  // Returns early when the slot has not been allocated yet (e.g. a C++ test that constructs Program without
-  // materializing the runtime and then triggers `Program::finalize -> synchronize`).
-  static_assert(std::atomic<int64_t>::is_always_lock_free,
-                "std::atomic<int64_t> must be lock-free for the reinterpret_cast pattern below to be portable");
-  if (adstack_overflow_flag_host_ptr_ == nullptr) {
-    return;
-  }
-  int64_t flag = reinterpret_cast<std::atomic<int64_t> *>(adstack_overflow_flag_host_ptr_)
-                     ->exchange(0, std::memory_order_relaxed);
-  if (flag == 0) {
-    return;
-  }
-  // Drain the companion task-id slot in the same poll. Both slots cleared so the next overflow records a fresh
-  // identity. `task_id == 0` means the kernel that overflowed pre-dates the registry wiring or its
-  // `ad_stack.registry_id` was unset for any reason (e.g.
a deserialised offline-cache task that has not yet been
-  // re-registered); the diagnose helper falls through to the generic dual-cause message in that case.
-  uint32_t task_id = 0;
-  if (adstack_overflow_task_id_host_ptr_ != nullptr) {
-    int64_t recorded = reinterpret_cast<std::atomic<int64_t> *>(adstack_overflow_task_id_host_ptr_)
-                           ->exchange(0, std::memory_order_relaxed);
-    task_id = static_cast<uint32_t>(recorded);
-  }
-  Program *prog = (program_impl_ != nullptr) ? program_impl_->program : nullptr;
-  std::string diagnostic;
-  if (prog != nullptr) {
-    auto diag = prog->adstack_cache().diagnose_adstack_overflow(task_id);
-    diagnostic = std::move(diag.message);
-    // Auto-invalidate the per-task metadata caches when the synchronous sizer rerun confirmed the cache is stale
-    // (DLPack-bypass cause). The current run is corrupted (we are about to raise), but the next launch's sizer
-    // reruns from scratch against the live (mutated) state and the kernel runs to completion without further
-    // user intervention. Unknown / Quadrants-bug cases skip the invalidation so a real sizer bug is not masked
-    // by silent recompute.
-    if (diag.confirmed_invalid_cache) {
-      prog->adstack_cache().invalidate_all_per_task();
-    }
-  } else {
-    diagnostic =
-        "Adstack overflow: a reverse-mode autodiff kernel pushed more elements than the adstack capacity "
-        "allows.";
-  }
-  throw QuadrantsAssertionError(
-      "Adstack overflow: a reverse-mode autodiff kernel pushed more elements "
-      "than the adstack capacity allows. Raised at the next Quadrants Python "
-      "entry rather than at the offending kernel launch.\n" +
-      diagnostic);
-}
-
-std::size_t LlvmRuntimeExecutor::publish_adstack_metadata(const AdStackSizingInfo &ad_stack,
-                                                          std::size_t num_threads,
-                                                          LaunchContextBuilder *ctx,
-                                                          void *device_runtime_context_ptr) {
-  const auto n_stacks = ad_stack.allocas.size();
-  if (n_stacks == 0 || num_threads == 0) {
-    return 0;
-  }
-  auto align_up_8 = [](std::size_t n) -> std::size_t { return (n + 7u) & ~std::size_t{7u}; };
-  // Allocate / grow the two device-side metadata arrays. Capacity is in u64 entries, kept at or above n_stacks.
-  // On GPU these buffers are written exclusively by the device-side sizer kernel (`runtime_eval_adstack_size_expr`);
-  // on CPU the host evaluator writes them directly via `std::memcpy`. Either way the pointers published into
-  // `runtime->adstack_offsets` / `adstack_max_sizes` stay stable across launches unless we grow here.
- auto grow_to = [&](DeviceAllocationUnique &alloc, std::size_t capacity_u64) { - Device::AllocParams params{}; - params.size = capacity_u64 * sizeof(uint64_t); - params.host_read = false; - params.host_write = false; - params.export_sharing = false; - params.usage = AllocUsage::Storage; - DeviceAllocation new_alloc; - RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); - QD_ERROR_IF(res != RhiResult::success, "Failed to allocate {} bytes for adstack metadata array (err: {})", - params.size, int(res)); - alloc = std::make_unique(std::move(new_alloc)); - }; - if (n_stacks > adstack_metadata_capacity_) { - std::size_t new_cap = std::max(n_stacks, 2 * adstack_metadata_capacity_); - grow_to(adstack_offsets_alloc_, new_cap); - grow_to(adstack_max_sizes_alloc_, new_cap); - adstack_metadata_capacity_ = new_cap; - } - void *offsets_dev_ptr = get_device_alloc_info_ptr(*adstack_offsets_alloc_); - void *max_sizes_dev_ptr = get_device_alloc_info_ptr(*adstack_max_sizes_alloc_); - - auto copy_h2d = [&](void *dst, const void *src, std::size_t bytes) { - if (config_.arch == Arch::cuda) { -#if defined(QD_WITH_CUDA) - CUDADriver::get_instance().memcpy_host_to_device(dst, const_cast(src), bytes); -#else - QD_NOT_IMPLEMENTED; -#endif - } else if (config_.arch == Arch::amdgpu) { -#if defined(QD_WITH_AMDGPU) - AMDGPUDriver::get_instance().memcpy_host_to_device(dst, const_cast(src), bytes); -#else - QD_NOT_IMPLEMENTED; -#endif - } else { - std::memcpy(dst, src, bytes); - } - }; - auto copy_d2h = [&](void *dst, const void *src, std::size_t bytes) { - if (config_.arch == Arch::cuda) { -#if defined(QD_WITH_CUDA) - CUDADriver::get_instance().memcpy_device_to_host(dst, const_cast(src), bytes); -#else - QD_NOT_IMPLEMENTED; -#endif - } else if (config_.arch == Arch::amdgpu) { -#if defined(QD_WITH_AMDGPU) - AMDGPUDriver::get_instance().memcpy_device_to_host(dst, const_cast(src), bytes); -#else - QD_NOT_IMPLEMENTED; -#endif - } else { - std::memcpy(dst, src, bytes); - } - }; - - // Cache the runtime-field addresses on the first call; then publish the metadata-array pointers into the - // runtime struct. The stride field is written by the sizer on GPU and by this function on CPU, so we cache the - // address either way. - if (runtime_adstack_stride_field_ptr_ == nullptr) { - auto *const runtime_jit = get_runtime_jit_module(); - runtime_jit->call("runtime_get_adstack_metadata_field_ptrs", llvm_runtime_); - // Slot order: combined-stride, offsets, max_sizes, float-stride, int-stride. Slots 0/1/2 keep the legacy ordering - // for code paths that have not migrated to the split layout; slots 3/4 are new. 
- runtime_adstack_stride_field_ptr_ = quadrants_union_cast_with_different_sizes( - fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_)); - runtime_adstack_offsets_field_ptr_ = quadrants_union_cast_with_different_sizes( - fetch_result_uint64(quadrants_result_buffer_ret_value_id + 1, result_buffer_cache_)); - runtime_adstack_max_sizes_field_ptr_ = quadrants_union_cast_with_different_sizes( - fetch_result_uint64(quadrants_result_buffer_ret_value_id + 2, result_buffer_cache_)); - runtime_adstack_stride_float_field_ptr_ = quadrants_union_cast_with_different_sizes( - fetch_result_uint64(quadrants_result_buffer_ret_value_id + 3, result_buffer_cache_)); - runtime_adstack_stride_int_field_ptr_ = quadrants_union_cast_with_different_sizes( - fetch_result_uint64(quadrants_result_buffer_ret_value_id + 4, result_buffer_cache_)); - } - // The pointed-to scratch allocations are stable across launches (only `grow_to` swaps them). Skip the per-launch - // h2d that publishes the pointer values whenever they have not changed since the last call. On HIP / CUDA each - // skipped pointer-publish is one queue round-trip the launcher would otherwise pay; on a typical reverse-mode - // sweep this fires thousands of times. - if (offsets_dev_ptr != adstack_offsets_dev_ptr_published_) { - copy_h2d(runtime_adstack_offsets_field_ptr_, &offsets_dev_ptr, sizeof(void *)); - adstack_offsets_dev_ptr_published_ = offsets_dev_ptr; - } - if (max_sizes_dev_ptr != adstack_max_sizes_dev_ptr_published_) { - copy_h2d(runtime_adstack_max_sizes_field_ptr_, &max_sizes_dev_ptr, sizeof(void *)); - adstack_max_sizes_dev_ptr_published_ = max_sizes_dev_ptr; - } - - std::size_t stride = 0; - const bool is_gpu_llvm = (config_.arch == Arch::cuda || config_.arch == Arch::amdgpu); - - // Shared GPU async publish helper: pack `[stride_combined, stride_float, stride_int, offsets[n_stacks], - // max_sizes[n_stacks]]` into the pinned-host scratch (grow on demand, double-amortised), then issue 5 async H2Ds - // on the active stream and record the completion event. Used by both the host-eval branch (CUDA / AMDGPU - // resolvable size_exprs) and the on-device-sizer cache-hit branch. The driver's H2D DMA reads from the pinned - // bytes at execution time, so a `wait_pending()` at the top of the next call defends against an unusual - // interleaving where the GPU queue is backlogged and the next launch enters before the previous launch's last - // copy has been consumed. Only callable when `is_gpu_llvm` is true. 
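For concreteness, a worked instance of the packed pinned payload for a kernel with two stacks (n_stacks = 2), following the layout named above:

// header: 3 u64 -> pinned[0] = stride_combined, pinned[1] = stride_float, pinned[2] = stride_int
// arrays: pinned[3..4] = offsets[0..1], pinned[5..6] = max_sizes[0..1]
// total:  7 u64 = 56 bytes staged once in pinned memory, then fanned out as the five async H2Ds
// (three single-u64 stride writes into the runtime struct, two array writes into the metadata buffers).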
- auto publish_metadata_pinned_async = [&](const uint64_t *offsets_src, const uint64_t *max_sizes_src, - uint64_t stride_combined_u64, uint64_t stride_float_u64, - uint64_t stride_int_u64) { - const std::size_t header_bytes = 3 * sizeof(uint64_t); - const std::size_t array_bytes = n_stacks * sizeof(uint64_t); - const std::size_t total_bytes = header_bytes + 2 * array_bytes; - auto wait_pending = [this]() { - if (!pinned_metadata_event_pending_) { - return; - } -#if defined(QD_WITH_CUDA) - if (config_.arch == Arch::cuda) { - CUDADriver::get_instance().event_synchronize(pinned_metadata_event_); - } -#endif -#if defined(QD_WITH_AMDGPU) - if (config_.arch == Arch::amdgpu) { - AMDGPUDriver::get_instance().event_synchronize(pinned_metadata_event_); - } -#endif - pinned_metadata_event_pending_ = false; - }; - if (total_bytes > pinned_metadata_scratch_capacity_) { - wait_pending(); - if (pinned_metadata_scratch_ != nullptr) { -#if defined(QD_WITH_CUDA) - if (config_.arch == Arch::cuda) { - CUDADriver::get_instance().mem_free_host(pinned_metadata_scratch_); - } -#endif -#if defined(QD_WITH_AMDGPU) - if (config_.arch == Arch::amdgpu) { - AMDGPUDriver::get_instance().mem_free_host(pinned_metadata_scratch_); - } -#endif - pinned_metadata_scratch_ = nullptr; - } - const std::size_t new_capacity = std::max(total_bytes, 2 * pinned_metadata_scratch_capacity_); -#if defined(QD_WITH_CUDA) - if (config_.arch == Arch::cuda) { - CUDADriver::get_instance().mem_alloc_host(&pinned_metadata_scratch_, new_capacity); - } -#endif -#if defined(QD_WITH_AMDGPU) - if (config_.arch == Arch::amdgpu) { - // `hipHostMallocDefault == 0`. Coherent / portable / write-combined flags are intentionally not set; the - // workload is small payloads written linearly by the host and DMA-read by the GPU once. - AMDGPUDriver::get_instance().mem_alloc_host(&pinned_metadata_scratch_, new_capacity, 0u); - } -#endif - pinned_metadata_scratch_capacity_ = new_capacity; - } - if (pinned_metadata_event_ == nullptr) { - // `cuEventCreate` flag `0` (CU_EVENT_DEFAULT) means timing-enabled, which the driver costs us nothing to set - // up here and lets future profilers attach without re-creating the event. `hipEventCreateWithFlags` takes - // the same encoding. -#if defined(QD_WITH_CUDA) - if (config_.arch == Arch::cuda) { - CUDADriver::get_instance().event_create(&pinned_metadata_event_, 0u); - } -#endif -#if defined(QD_WITH_AMDGPU) - if (config_.arch == Arch::amdgpu) { - AMDGPUDriver::get_instance().event_create(&pinned_metadata_event_, 0u); - } -#endif - } - wait_pending(); - auto *pinned = static_cast(pinned_metadata_scratch_); - pinned[0] = stride_combined_u64; - pinned[1] = stride_float_u64; - pinned[2] = stride_int_u64; - std::memcpy(pinned + 3, offsets_src, array_bytes); - std::memcpy(pinned + 3 + n_stacks, max_sizes_src, array_bytes); - // Queue the metadata copies on the stream the subsequent main-kernel dispatch will run on, so the GPU - // stream-orders the copies before the kernel reads `adstack_max_sizes` etc. CUDA: `CUDAContext::get_stream()` - // (configurable via `set_stream`, defaults to the null stream); AMDGPU: always the default stream because - // `AMDGPUContext::launch` passes `nullptr` to `hipLaunchKernel`. 
-#if defined(QD_WITH_CUDA) - if (config_.arch == Arch::cuda) { - void *active_stream = CUDAContext::get_instance().get_stream(); - CUDADriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_field_ptr_, pinned, - sizeof(uint64_t), active_stream); - if (runtime_adstack_stride_float_field_ptr_ != nullptr) { - CUDADriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_float_field_ptr_, pinned + 1, - sizeof(uint64_t), active_stream); - } - if (runtime_adstack_stride_int_field_ptr_ != nullptr) { - CUDADriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_int_field_ptr_, pinned + 2, - sizeof(uint64_t), active_stream); - } - CUDADriver::get_instance().memcpy_host_to_device_async(offsets_dev_ptr, pinned + 3, array_bytes, active_stream); - CUDADriver::get_instance().memcpy_host_to_device_async(max_sizes_dev_ptr, pinned + 3 + n_stacks, array_bytes, - active_stream); - CUDADriver::get_instance().event_record(pinned_metadata_event_, active_stream); - } -#endif -#if defined(QD_WITH_AMDGPU) - if (config_.arch == Arch::amdgpu) { - void *active_stream = nullptr; - AMDGPUDriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_field_ptr_, pinned, - sizeof(uint64_t), active_stream); - if (runtime_adstack_stride_float_field_ptr_ != nullptr) { - AMDGPUDriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_float_field_ptr_, pinned + 1, - sizeof(uint64_t), active_stream); - } - if (runtime_adstack_stride_int_field_ptr_ != nullptr) { - AMDGPUDriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_int_field_ptr_, pinned + 2, - sizeof(uint64_t), active_stream); - } - AMDGPUDriver::get_instance().memcpy_host_to_device_async(offsets_dev_ptr, pinned + 3, array_bytes, active_stream); - AMDGPUDriver::get_instance().memcpy_host_to_device_async(max_sizes_dev_ptr, pinned + 3 + n_stacks, array_bytes, - active_stream); - AMDGPUDriver::get_instance().event_record(pinned_metadata_event_, active_stream); - } -#endif - pinned_metadata_event_pending_ = true; - }; - - // Host-eval fast path. The on-device sizer kernel exists to handle one specific leaf, `ExternalTensorRead`, - // whose ndarray data lives in GPU-private memory (`cudaMalloc` / `hipMalloc`, no UVA fallback) and thus - // cannot be touched from the host. Every other SizeExpr leaf - `Const`, `BoundVariable`, - // `ExternalTensorShape`, `FieldLoad` - is host-resolvable through the existing `evaluate_adstack_size_expr` - // path, so when the kernel's SizeExprs are all `ExternalTensorRead`-free we can skip the encode + bytecode - // h2d + sizer-kernel launch + d2h-stride pipeline entirely and write the metadata directly via `copy_h2d`. - // On CUDA the saved `cuMemcpyDtoH` for the per-launch stride readback is the dominant cost: every reverse- - // mode kernel launch in a 100-substep test paid one such synchronous DtoH each, and that compound stall - // accounted for the bulk of the GPU launch overhead under adstack mode. The condition is computed once per - // launch by scanning each stack's `nodes` vector for an `ExternalTensorRead` leaf; the scan is O(total - // SizeExpr nodes), well below the cost of the cheapest h2d / d2h on any LLVM GPU backend. 
- bool all_size_exprs_host_resolvable = true; - for (std::size_t i = 0; i < n_stacks && all_size_exprs_host_resolvable; ++i) { - if (i >= ad_stack.size_exprs.size()) { - continue; - } - for (const auto &node : ad_stack.size_exprs[i].nodes) { - if (static_cast(node.kind) == SizeExpr::Kind::ExternalTensorRead) { - all_size_exprs_host_resolvable = false; - break; - } - } - } - const bool use_host_eval = !is_gpu_llvm || all_size_exprs_host_resolvable; - // Per-kind byte strides resolved either host-side (host-eval branch) or by reading back from the device runtime - // struct after the sizer kernel ran (GPU branch). Used below to size the float / int heaps independently for the - // unconditional split-heap layout. - std::size_t stride_float_bytes = 0; - std::size_t stride_int_bytes = 0; - if (use_host_eval) { - // CPU + GPU-without-ExternalTensorRead path: run the host evaluator directly. On CPU we use synchronous - // `copy_h2d` (just `std::memcpy` for that arch), but on CUDA / AMDGPU we ship the same payload through - // pinned-host memory via async `cuMemcpyHtoDAsync` / `hipMemcpyHtoDAsync` so the host returns immediately - // after queueing the copies on the default stream and the subsequent main-kernel launch (also on the - // default stream) stream-orders after the copies. The synchronous `cuMemcpyHtoD_v2` path used to block - // the host on every one of the three writes we issue per launch; with thousands of reverse-mode launches - // per `test_differentiable_rigid` run, those serial host stalls were a measurable fraction of wallclock. - // `FieldLoad` is serviced by `SNodeRwAccessorsBank` regardless of arch. - // Guard `program_impl_->program` lookups against the C++-only-tests setup where `program_impl_` itself is null; - // the on-device branch below already does this and falls back to `max_size_compile_time`. - Program *prog = (program_impl_ != nullptr) ? program_impl_->program : nullptr; - // Span the per-stack `evaluate_adstack_size_expr` calls below with one shared read cache. - SizeExprLaunchScope launch_scope; - std::vector host_max_sizes(n_stacks); - for (std::size_t i = 0; i < n_stacks; ++i) { - const SerializedSizeExpr *expr = (i < ad_stack.size_exprs.size()) ? &ad_stack.size_exprs[i] : nullptr; - int64_t v = -1; - if (expr != nullptr && !expr->nodes.empty() && prog != nullptr) { - v = evaluate_adstack_size_expr(*expr, prog, ctx); - } - if (v < 0) { - v = static_cast(ad_stack.allocas[i].max_size_compile_time); - } - host_max_sizes[i] = static_cast(std::max(v, 1)); - } - // Unconditional split-heap layout: float allocas live at `host_offsets[i]` within the float-only slice (addressed - // on the codegen side as `heap_float + row_id_var * stride_float + float_offset` for bound_expr tasks, or - // `heap_float + linear_tid * stride_float + float_offset` for non-bound_expr tasks); int allocas live at - // `host_offsets[i]` within the int-only slice (addressed as `heap_int + linear_tid * stride_int + int_offset`). - // Same scheme regardless of `bound_expr` so the heap layout matches the SPIR-V backend's unconditional split into - // `BufferType::AdStackHeapFloat` + `AdStackHeapInt`. The legacy combined-heap path is no longer used by the - // codegen; the combined stride / heap fields stay in the LLVMRuntime struct only as a transitional fallback for - // offline-cache-loaded kernels that predate the split, and the published `adstack_per_thread_stride` mirrors - // `stride_int` so any such kernel sees the smaller int-only stride. 
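A worked sizing instance for the offsets loop that follows, under assumed alloca shapes (three stacks, 4-byte entries; each row step is align_up_8 of an 8-byte top-of-stack counter plus entry_size * max_size):

// stack 0, float, max_size 10: step = align_up_8(8 + 4 * 10) = 48 -> float_offset 0
// stack 1, int,   max_size 3:  step = align_up_8(8 + 4 * 3)  = 24 -> int_offset   0
// stack 2, float, max_size 7:  step = align_up_8(8 + 4 * 7)  = 40 -> float_offset 48
// => stride_float = 88, stride_int = 24; a float alloca in a bound_expr task then lives at
//    heap_float + row_id_var * 88 + its float_offset.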
- std::vector host_offsets(n_stacks); - for (std::size_t i = 0; i < n_stacks; ++i) { - const std::size_t step = align_up_8(sizeof(int64_t) + ad_stack.allocas[i].entry_size_bytes * host_max_sizes[i]); - const bool is_float = ad_stack.allocas[i].heap_kind == AdStackAllocaInfo::HeapKind::Float; - host_offsets[i] = is_float ? stride_float_bytes : stride_int_bytes; - if (is_float) { - stride_float_bytes += step; - } else { - stride_int_bytes += step; - } - } - stride = stride_int_bytes; - uint64_t stride_combined_u64 = static_cast(stride); - uint64_t stride_float_u64 = static_cast(stride_float_bytes); - uint64_t stride_int_u64 = static_cast(stride_int_bytes); - if (!is_gpu_llvm) { - copy_h2d(offsets_dev_ptr, host_offsets.data(), n_stacks * sizeof(uint64_t)); - copy_h2d(max_sizes_dev_ptr, host_max_sizes.data(), n_stacks * sizeof(uint64_t)); - copy_h2d(runtime_adstack_stride_field_ptr_, &stride_combined_u64, sizeof(uint64_t)); - // Per-kind strides used by the split-heap codegen path; harmless when the codegen has not migrated yet (the - // kernel reads only the combined stride). Skipped when the cache is empty (first launch on a stale executor - // instance where `runtime_get_adstack_metadata_field_ptrs` populated only the legacy slots; the null check is - // defensive - any host writing to `nullptr` would crash with no diagnostic). - if (runtime_adstack_stride_float_field_ptr_ != nullptr) { - copy_h2d(runtime_adstack_stride_float_field_ptr_, &stride_float_u64, sizeof(uint64_t)); - } - if (runtime_adstack_stride_int_field_ptr_ != nullptr) { - copy_h2d(runtime_adstack_stride_int_field_ptr_, &stride_int_u64, sizeof(uint64_t)); - } - } else { - publish_metadata_pinned_async(host_offsets.data(), host_max_sizes.data(), stride_combined_u64, stride_float_u64, - stride_int_u64); - } - } else { - // GPU (CUDA / AMDGPU): encode the SizeExpr trees into device bytecode, upload, launch the sizer runtime - // function, read back just the computed stride. The sizer kernel writes `adstack_max_sizes[]`, - // `adstack_offsets[]`, and `adstack_per_thread_stride` directly into the runtime struct and the metadata - // arrays above - no further host-writes to those fields are needed this launch. - // - // Why this architecture rather than host-eval: on CUDA / AMDGPU the ndarray data lives in GPU-private memory - // (plain `cudaMalloc` / `hipMalloc`, not managed / unified), so the host evaluator's `ExternalTensorRead` - // deref reads garbage. Moving the interpreter on-device keeps the pointer semantics intact - it reads the - // data pointer out of `ctx->arg_buffer` (which the kernel will read too) and dereferences it where the - // memory lives, with no migration / readback of the ndarray payload itself. - // - // Per-task metadata cache fast path: the sizer kernel's output (offsets / max_sizes / strides) is a - // deterministic function of (a) the per-task `AdStackSizingInfo *` (compile-time bytecode shape, stable - // for the kernel's lifetime), (b) every SNode value a `FieldLoad` leaf reads, and (c) every ndarray - // value an `ExternalTensorRead` leaf reads. Each launcher (cpu / cuda / amdgpu) bumps - // `Program::snode_write_gen_` / `ndarray_data_gen_` for everything this kernel may mutate before - // calling here, so the per-source generation snapshots stored alongside the cached payload catch any - // input change between launches and force a fresh sizer dispatch when needed. 
On hit, the cached - // offsets / max_sizes / strides are republished into the runtime struct via the same `copy_h2d` paths - // the host-eval branch above uses, and the entire bytecode-encode + h2d + sizer-kernel launch + - // 3x DtoH-stride pipeline is skipped. The cost of the sizer dispatch + DtoH stalls is small per - // launch on CUDA / AMDGPU, but a long sequence of reverse-mode launches over the same kernel - // pays it once per launch; the cache amortises that to once per generation-bump. - Program *prog = (program_impl_ != nullptr) ? program_impl_->program : nullptr; - bool llvm_metadata_cache_hit = false; - if (prog != nullptr) { - AdStackCache::LlvmPerTaskAdStackCacheEntry entry; - if (prog->adstack_cache().try_llvm_per_task_ad_stack_cache_hit(static_cast(&ad_stack), ctx, - entry)) { - QD_ASSERT(entry.offsets.size() == n_stacks && entry.max_sizes.size() == n_stacks); - // Publish the cached payload through the pinned-host async pipeline shared with the host-eval - // branch above: one pinned-scratch pack + five `memcpy_host_to_device_async` issued on the same - // stream the main kernel will dispatch on, ordered behind the previous launch's - // `pinned_metadata_event_pending_` wait. Packing the same `[stride_combined, stride_float, - // stride_int, offsets[n_stacks], max_sizes[n_stacks]]` shape keeps both branches' DMA pattern - // identical and removes the per-launch sync round-trips a `copy_h2d` would otherwise impose; on - // CPU `copy_h2d` is `memcpy` already so we keep the direct path there. - if (!is_gpu_llvm) { - copy_h2d(offsets_dev_ptr, entry.offsets.data(), n_stacks * sizeof(uint64_t)); - copy_h2d(max_sizes_dev_ptr, entry.max_sizes.data(), n_stacks * sizeof(uint64_t)); - copy_h2d(runtime_adstack_stride_field_ptr_, &entry.stride_combined, sizeof(uint64_t)); - if (runtime_adstack_stride_float_field_ptr_ != nullptr) { - copy_h2d(runtime_adstack_stride_float_field_ptr_, &entry.stride_float, sizeof(uint64_t)); - } - if (runtime_adstack_stride_int_field_ptr_ != nullptr) { - copy_h2d(runtime_adstack_stride_int_field_ptr_, &entry.stride_int, sizeof(uint64_t)); - } - } else { - publish_metadata_pinned_async(entry.offsets.data(), entry.max_sizes.data(), entry.stride_combined, - entry.stride_float, entry.stride_int); - } - stride = static_cast(entry.stride_combined); - stride_float_bytes = static_cast(entry.stride_float); - stride_int_bytes = static_cast(entry.stride_int); - llvm_metadata_cache_hit = true; - } - } - if (!llvm_metadata_cache_hit) { - std::vector bytecode; - if (program_impl_ != nullptr && program_impl_->program != nullptr) { - bytecode = encode_adstack_size_expr_device_bytecode(ad_stack, program_impl_->program, ctx); - } else { - // No program attached (rare: C++-only tests that construct Program without a full runtime). Fall through - // to compile-time bounds by emitting an empty-tree bytecode - the device interpreter sees - // `root_node_idx == -1` for every stack and routes to `max_size_compile_time`. - bytecode = encode_adstack_size_expr_device_bytecode(ad_stack, nullptr, ctx); - } - // Grow the scratch buffer if the bytecode outgrew the cached capacity. Amortised doubling keeps the - // allocation traffic O(log max_bytecode_bytes) across a run. 
- const std::size_t bytecode_bytes = bytecode.size(); - if (bytecode_bytes > adstack_sizer_bytecode_capacity_) { - std::size_t new_cap = std::max(bytecode_bytes, 2 * adstack_sizer_bytecode_capacity_); - Device::AllocParams params{}; - params.size = new_cap; - params.host_read = false; - params.host_write = false; - params.export_sharing = false; - params.usage = AllocUsage::Storage; - DeviceAllocation new_alloc; - RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); - QD_ERROR_IF(res != RhiResult::success, - "Failed to allocate {} bytes for the adstack sizer bytecode scratch buffer (err: {})", params.size, - int(res)); - adstack_sizer_bytecode_alloc_ = std::make_unique(std::move(new_alloc)); - adstack_sizer_bytecode_capacity_ = new_cap; - } - void *bytecode_dev_ptr = get_device_alloc_info_ptr(*adstack_sizer_bytecode_alloc_); - copy_h2d(bytecode_dev_ptr, bytecode.data(), bytecode_bytes); - - // Invoke the device interpreter. On CUDA / AMDGPU `JITModule::call` launches this as a single-thread kernel - // on the default stream and stream-orders it before the subsequent main-kernel dispatch, so the writes we - // do here are visible by the time the user's kernel reads `adstack_max_sizes` etc. - // - // The sizer kernel dereferences `ctx->arg_buffer` on device (that's how it resolves `ExternalTensorRead` leaves - // against ndarray pointers the caller packed into the arg buffer). AMDGPU always stages a device-side copy of - // `RuntimeContext` because HIP has no UVA fallback and the host pointer faults with `hipErrorIllegalAddress`. - // CUDA stages the device copy only when the driver + kernel do not expose HMM / system-allocated memory (queried - // via `CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS`): CUDA UVA covers pinned / CUDA-managed memory only, not the - // plain `std::make_unique()` backing, so a host pointer works on HMM-capable setups but faults - // otherwise (Turing without HMM, Windows, pre-535 Linux drivers) as `CUDA_ERROR_ILLEGAL_ADDRESS` at the next DtoH - // sync `illegal memory access ... while calling memcpy_device_to_host`. When the caller passes `nullptr` - // (HMM-capable CUDA) we fall back to the host pointer; the launcher gates the allocation so HMM-equipped setups - // pay no staging cost. - auto *const runtime_jit = get_runtime_jit_module(); - void *runtime_context_ptr_for_sizer = - device_runtime_context_ptr != nullptr ? device_runtime_context_ptr : static_cast(&ctx->get_context()); - runtime_jit->call("runtime_eval_adstack_size_expr", llvm_runtime_, - runtime_context_ptr_for_sizer, bytecode_dev_ptr); - - // Read back the per-kind strides published by `runtime_eval_adstack_size_expr` so we can size the float and int - // heaps independently host-side. The combined stride is unused by the split-heap codegen but kept around for - // legacy-kernel backward compatibility (mirrors `stride_int` in the unconditional-split layout). 
- uint64_t stride_combined_readback = 0; - uint64_t stride_float_readback = 0; - uint64_t stride_int_readback = 0; - copy_d2h(&stride_combined_readback, runtime_adstack_stride_field_ptr_, sizeof(uint64_t)); - if (runtime_adstack_stride_float_field_ptr_ != nullptr) { - copy_d2h(&stride_float_readback, runtime_adstack_stride_float_field_ptr_, sizeof(uint64_t)); - } - if (runtime_adstack_stride_int_field_ptr_ != nullptr) { - copy_d2h(&stride_int_readback, runtime_adstack_stride_int_field_ptr_, sizeof(uint64_t)); - } - stride = static_cast(stride_combined_readback); - stride_float_bytes = static_cast(stride_float_readback); - stride_int_bytes = static_cast(stride_int_readback); - - // Record the cache entry so the next launch on this kernel can skip the sizer pipeline. We also - // need to read back the offsets / max_sizes arrays the sizer wrote to the device buffers - the - // cache hit path above republishes them, so we must store host copies here. n_stacks is small - // (a few dozen at most for any reasonable kernel) so the extra DtoH cost is negligible - // compared to the dispatch + sizer-kernel launch we are about to amortise away. - if (prog != nullptr) { - std::vector offsets_readback(n_stacks); - std::vector max_sizes_readback(n_stacks); - copy_d2h(offsets_readback.data(), offsets_dev_ptr, n_stacks * sizeof(uint64_t)); - copy_d2h(max_sizes_readback.data(), max_sizes_dev_ptr, n_stacks * sizeof(uint64_t)); - // Walk size_exprs structurally to gather the dependency keys (snode_ids referenced via - // FieldLoad, arg_ids referenced via ExternalTensorShape / ExternalTensorRead). Pure tree - // inspection - no live value reads, no nested kernel launches. Mirrors the SPIR-V analogue. - std::unordered_set snode_ids; - std::unordered_set arg_ids; - for (const auto &expr : ad_stack.size_exprs) { - for (const auto &node : expr.nodes) { - switch (static_cast(node.kind)) { - case SizeExpr::Kind::FieldLoad: - if (node.snode_id >= 0) - snode_ids.insert(node.snode_id); - break; - case SizeExpr::Kind::ExternalTensorShape: - case SizeExpr::Kind::ExternalTensorRead: - if (!node.arg_id_path.empty()) - arg_ids.insert(node.arg_id_path.front()); - break; - default: - break; - } - } - } - std::vector> snode_gens; - snode_gens.reserve(snode_ids.size()); - for (int snode_id : snode_ids) { - snode_gens.emplace_back(snode_id, prog->adstack_cache().snode_write_gen(snode_id)); - } - std::vector> arg_gens; - arg_gens.reserve(arg_ids.size()); - for (int arg_id : arg_ids) { - ArgArrayPtrKey data_key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY}; - auto ap_it = ctx->array_ptrs.find(data_key); - void *devalloc = (ap_it == ctx->array_ptrs.end()) ? nullptr : ap_it->second; - arg_gens.emplace_back(arg_id, devalloc, prog->adstack_cache().ndarray_data_gen(devalloc)); - } - prog->adstack_cache().record_llvm_per_task_ad_stack( - static_cast(&ad_stack), std::move(offsets_readback), std::move(max_sizes_readback), - stride_combined_readback, stride_float_readback, stride_int_readback, std::move(snode_gens), - std::move(arg_gens)); - } - } // end if (!llvm_metadata_cache_hit) - } - - // Legacy combined heap: not allocated. The unconditional-split codegen reads `heap_float` for f32 allocas and - // `heap_int` for i32 / u1 allocas; the legacy `adstack_heap_buffer` field is never dereferenced by freshly-compiled - // kernels. Skipping the allocation drops ~stride_int_bytes * num_threads of unused VRAM (multiple GB on heavy - // reverse-mode kernels on Nvidia / AMDGPU at saturating_grid_dim). 
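The generation-snapshot scheme the cache entry above records can be sketched independently of the AdStackCache API; a minimal sketch with hypothetical standalone types:

#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

// Generation-stamped caching: every writer bumps the generation of what it mutates; a cached
// entry snapshots the generations of everything it read and stays valid only while they all match.
struct GenCache {
  std::unordered_map<int, uint64_t> live_gen;  // dependency id -> current generation

  struct Entry {
    std::vector<std::pair<int, uint64_t>> deps;  // (dependency id, generation at record time)
    uint64_t payload = 0;                        // stands in for offsets / max_sizes / strides
  };
  std::unordered_map<const void *, Entry> entries;  // keyed by per-task identity

  void touch(int dep) {
    ++live_gen[dep];  // called by any launcher that may mutate `dep` before the publish
  }

  bool hit(const void *task, uint64_t &payload_out) const {
    auto it = entries.find(task);
    if (it == entries.end()) {
      return false;
    }
    for (const auto &[dep, gen] : it->second.deps) {
      auto g = live_gen.find(dep);
      if (g == live_gen.end() || g->second != gen) {
        return false;  // a dependency moved on since record time: stale, force a fresh dispatch
      }
    }
    payload_out = it->second.payload;
    return true;
  }
};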
- std::size_t needed_bytes = 0; - // Always allocate the int heap at `num_threads * stride_int_bytes` worst case. Int allocas are autodiff-emitted at - // the offload root unconditionally (loop-counter recovery, branch flags), so every dispatched thread reaches them and - // the eager `linear_tid * stride_int + int_offset` layout demands a row per thread. - if (stride_int_bytes > 0) { - const std::size_t int_bytes = stride_int_bytes * num_threads; - if (std::getenv("QD_DEBUG_ADSTACK")) { - std::fprintf(stderr, - "[adstack_heap] arch=llvm kind=I src=worst_case_num_threads num_threads=%zu stride=%zu " - "required_bytes=%zu (%.2f MB)\n", - num_threads, stride_int_bytes, int_bytes, double(int_bytes) / (1024.0 * 1024.0)); - std::fflush(stderr); - } - ensure_adstack_heap_int(int_bytes); - } - // Float heap: deferred to `ensure_per_task_float_heap_post_reducer` for tasks with a captured `bound_expr` (the - // reducer-published count drives the sizing); for non-bound_expr tasks size at `num_threads * stride_float_bytes` - // worst case here. The eager float path uses `linear_tid` as the row index so every dispatched thread needs backing - // storage; only the bound_expr path can shrink to `count * stride_float_bytes`. - if (stride_float_bytes > 0 && !ad_stack.bound_expr.has_value()) { - const std::size_t float_bytes = stride_float_bytes * num_threads; - if (std::getenv("QD_DEBUG_ADSTACK")) { - std::fprintf(stderr, - "[adstack_heap] arch=llvm kind=F src=worst_case_num_threads_no_bound_expr num_threads=%zu " - "stride=%zu required_bytes=%zu (%.2f MB)\n", - num_threads, stride_float_bytes, float_bytes, double(float_bytes) / (1024.0 * 1024.0)); - std::fflush(stderr); - } - ensure_adstack_heap_float(float_bytes); - } - last_published_stride_float_bytes_ = stride_float_bytes; - return needed_bytes; -} - -} // namespace quadrants::lang diff --git a/quadrants/runtime/llvm/llvm_runtime_executor.h b/quadrants/runtime/llvm/llvm_runtime_executor.h index 623f711f79..b642bb232d 100644 --- a/quadrants/runtime/llvm/llvm_runtime_executor.h +++ b/quadrants/runtime/llvm/llvm_runtime_executor.h @@ -1,7 +1,9 @@ #pragma once #include +#include #include +#include #ifdef QD_WITH_LLVM @@ -181,6 +183,25 @@ class LlvmRuntimeExecutor { std::size_t num_threads, LaunchContextBuilder *ctx); + // Max-reducer dispatch on LLVM. For each captured `StaticAdStackMaxReducerSpec` across every task in `tasks`, hits + // `AdStackCache::try_max_reducer_cache_hit` first; on miss h2d-copies the params blob + body bytecode and invokes + // `runtime_eval_adstack_max_reduce` via the runtime JIT. Single dispatch path covers CPU (host call), CUDA, and + // AMDGPU. The returned map is keyed by `(registry_id, stack_id, mor_node_idx)` packed via the same encoding the gfx + // variant uses, so `substitute_precomputed_max_over_range` works backend-agnostically. Caller invokes this BEFORE the + // per-task `publish_adstack_metadata` loop and passes the result map down to each per-task `publish` call so the + // encoder substitutes captured `MaxOverRange`s before walking the tree. `MaxReducerResultMap` is defined in + // `quadrants/program/adstack_size_expr_eval.h`; declared inline here to avoid pulling that header into every + // translation unit that includes `llvm_runtime_executor.h`. + std::unordered_map dispatch_max_reducers_for_tasks(const std::vector &ad_stacks, + LaunchContextBuilder *ctx, + void *device_runtime_context_ptr); + // Convenience overload that extracts each task's `ad_stack` and forwards to the primary entry point. 
Lets the CUDA / + // AMDGPU per-arch launchers call into the dispatcher with the `OffloadedTask` list they already hold, without each + // launcher copy-pasting the per-task `ad_stack` extraction loop. + std::unordered_map dispatch_max_reducers_for_tasks(const std::vector &tasks, + LaunchContextBuilder *ctx, + void *device_runtime_context_ptr); + // Return (and lazily cache) the device pointer to `runtime->temporaries`, the global temporary buffer backing // `GlobalTemporaryStmt` loads and stores. GPU kernel launchers use this to read back dynamic range_for bounds (begin // / end i32 values at known byte offsets) via a host-side DtoH memcpy when sizing the adstack heap. Cached because @@ -324,6 +345,17 @@ class LlvmRuntimeExecutor { // allocation. void *runtime_adstack_row_counters_field_ptr_{nullptr}; void *runtime_adstack_bound_row_capacities_field_ptr_{nullptr}; + // Cached address of `LLVMRuntime::adstack_max_reducer_outputs` (a `i64 *` field). Resolved once per program lifetime + // via `runtime_get_adstack_max_reducer_field_ptr`; the per-launch dispatch writes the (possibly grown) device buffer + // pointer to this address so `runtime_eval_adstack_max_reduce` deref's the live allocation. + void *runtime_adstack_max_reducer_outputs_field_ptr_{nullptr}; + + // Per-launch transient: the `MaxReducerResultMap` populated by `dispatch_max_reducers_for_tasks` and read by + // `publish_adstack_metadata`. Owned by the executor across the per-task publish loop within a single kernel launch; + // cleared at the top of every `dispatch_max_reducers_for_tasks` call so a kernel without captured specs sees an empty + // map. Keeping the map on the executor avoids threading it through `publish_adstack_metadata`'s call sites in three + // per-arch launchers. + std::unordered_map current_max_reducer_results_; // Host-owned storage for the per-kernel lazy-claim arrays: `adstack_row_counters_alloc_`: u32[num_tasks] atomic // counter the codegen-emitted LCA-block row claim atomic-rmws @@ -370,6 +402,21 @@ class LlvmRuntimeExecutor { DeviceAllocationUnique adstack_sizer_bytecode_alloc_ = nullptr; std::size_t adstack_sizer_bytecode_capacity_{0}; + // Per-launch scratch buffers for the max-reducer dispatch. One holds a single `LlvmAdStackMaxReducerDeviceParams` + // blob per call (the runtime function is dispatched per spec); the other holds the body bytecode (concatenated + // `AdStackSizeExprDeviceNode` array followed by indices). Both grow amortised-doubling and are reused across specs + // within a launch and across launches. Unused on CPU when the runtime function is invoked directly host-side without + // staging. + DeviceAllocationUnique adstack_max_reducer_params_alloc_ = nullptr; + std::size_t adstack_max_reducer_params_capacity_{0}; + DeviceAllocationUnique adstack_max_reducer_bytecode_alloc_ = nullptr; + std::size_t adstack_max_reducer_bytecode_capacity_{0}; + // Per-launch output buffer the runtime function writes into (`runtime->adstack_max_reducer_outputs[output_slot]` = + // i64 dispatched value). Sized to fit the kernel's spec count; grown amortised-doubling. Backed by the runtime + // module's `adstack_max_reducer_outputs` field via `runtime_LLVMRuntime_set_adstack_max_reducer_outputs`. 
+ DeviceAllocationUnique adstack_max_reducer_outputs_alloc_ = nullptr; + std::size_t adstack_max_reducer_outputs_capacity_{0}; + // Pinned (page-locked) host scratch + completion event used by the host-eval branch of `publish_adstack_metadata` on // CUDA / AMDGPU to issue the per-launch adstack metadata writes asynchronously. With pageable host memory // `cuMemcpyHtoDAsync` synchronises on the staging copy, so the source MUST be pinned for the async copy to be truly diff --git a/quadrants/runtime/llvm/runtime_module/runtime.cpp b/quadrants/runtime/llvm/runtime_module/runtime.cpp index 3063022ade..165631ab74 100644 --- a/quadrants/runtime/llvm/runtime_module/runtime.cpp +++ b/quadrants/runtime/llvm/runtime_module/runtime.cpp @@ -26,6 +26,7 @@ #include "quadrants/inc/cuda_kernel_utils.inc.h" #include "quadrants/ir/adstack_size_expr_device.h" #include "quadrants/ir/static_adstack_bound_reducer_device.h" +#include "quadrants/ir/static_adstack_max_reducer_device.h" #include "quadrants/math/arithmetic.h" struct RuntimeContext; @@ -647,6 +648,15 @@ struct LLVMRuntime { u32 *adstack_bound_row_capacities = nullptr; u64 adstack_bound_row_capacities_capacity = 0; + // Per-spec output slot for the max reducer. One i64 per captured `StaticAdStackMaxReducerSpec`, written by + // `runtime_eval_adstack_max_reduce` during the per-launch dispatch and read by the host launcher to substitute the + // value as a `Const` into the per-stack `SerializedSizeExpr` tree before any LLVM eval path walks it. Sized / grown + // by the LlvmRuntimeExecutor lazy-allocate path on the first launch that has captured specs; cleared to INT64_MIN + // before each dispatch so the running-max sentinel is well-defined when `length == 0` (returns INT64_MIN; the caller + // floors at 0 + clamps to compile-time). + i64 *adstack_max_reducer_outputs = nullptr; + u64 adstack_max_reducer_outputs_capacity = 0; + Ptr result_buffer; i32 allocator_lock; @@ -693,6 +703,7 @@ STRUCT_FIELD(LLVMRuntime, adstack_offsets); STRUCT_FIELD(LLVMRuntime, adstack_max_sizes); STRUCT_FIELD(LLVMRuntime, adstack_row_counters); STRUCT_FIELD(LLVMRuntime, adstack_bound_row_capacities); +STRUCT_FIELD(LLVMRuntime, adstack_max_reducer_outputs); STRUCT_FIELD(LLVMRuntime, adstack_overflow_flag_dev_ptr); STRUCT_FIELD(LLVMRuntime, adstack_overflow_task_id_dev_ptr); @@ -840,6 +851,14 @@ void runtime_get_adstack_lazy_claim_field_ptrs(LLVMRuntime *runtime) { runtime->set_result(quadrants_result_buffer_ret_value_id + 1, (u64)(void *)&runtime->adstack_bound_row_capacities); } +// Companion to `runtime_get_adstack_lazy_claim_field_ptrs` for the max-reducer outputs. The output buffer is +// per-launch-allocated host-side and the field-address is cached once so the per-launch publish only writes the new +// array pointer (when the buffer grows) and the read-back per-spec slot reads through the runtime's stable address. +// Single field, single result slot. +void runtime_get_adstack_max_reducer_field_ptr(LLVMRuntime *runtime) { + runtime->set_result(quadrants_result_buffer_ret_value_id, (u64)(void *)&runtime->adstack_max_reducer_outputs); +} + // Device-resident adstack SizeExpr interpreter. Runs on whatever backend the LLVM runtime JIT-compiles this // bitcode to: a plain C function call on CPU, a single-thread kernel launch on CUDA / AMDGPU. 
The bytecode buffer // layout is defined by `quadrants/ir/adstack_size_expr_device.h` and produced host-side by @@ -905,7 +924,8 @@ i64 device_load_element(const char *data_ptr, i64 linear, i32 prim_dt) { } } -i64 device_eval_node(const quadrants::lang::AdStackSizeExprDeviceNode *nodes, +i64 device_eval_node(LLVMRuntime *runtime, + const quadrants::lang::AdStackSizeExprDeviceNode *nodes, const i32 *indices, i32 node_idx, DeviceEvalScope *scope, @@ -916,42 +936,43 @@ i64 device_eval_node(const quadrants::lang::AdStackSizeExprDeviceNode *nodes, case K::kConst: return node.const_value; case K::kAdd: - return device_eval_node(nodes, indices, node.operand_a, scope, arg_buffer) + - device_eval_node(nodes, indices, node.operand_b, scope, arg_buffer); + return device_eval_node(runtime, nodes, indices, node.operand_a, scope, arg_buffer) + + device_eval_node(runtime, nodes, indices, node.operand_b, scope, arg_buffer); case K::kSub: { - // Match the host evaluator: clamp negative trip counts to zero so an underflowed `end - begin` doesn't - // poison a surrounding `Mul` / `MaxOverRange` product. - i64 lhs = device_eval_node(nodes, indices, node.operand_a, scope, arg_buffer); - i64 rhs = device_eval_node(nodes, indices, node.operand_b, scope, arg_buffer); + // Match the host evaluator: clamp negative trip counts to zero so an underflowed `end - begin` doesn't poison a + // surrounding `Mul` / `MaxOverRange` product. + i64 lhs = device_eval_node(runtime, nodes, indices, node.operand_a, scope, arg_buffer); + i64 rhs = device_eval_node(runtime, nodes, indices, node.operand_b, scope, arg_buffer); i64 diff = lhs - rhs; return diff > 0 ? diff : 0; } case K::kMul: - return device_eval_node(nodes, indices, node.operand_a, scope, arg_buffer) * - device_eval_node(nodes, indices, node.operand_b, scope, arg_buffer); + return device_eval_node(runtime, nodes, indices, node.operand_a, scope, arg_buffer) * + device_eval_node(runtime, nodes, indices, node.operand_b, scope, arg_buffer); case K::kMax: { - i64 lhs = device_eval_node(nodes, indices, node.operand_a, scope, arg_buffer); - i64 rhs = device_eval_node(nodes, indices, node.operand_b, scope, arg_buffer); + i64 lhs = device_eval_node(runtime, nodes, indices, node.operand_a, scope, arg_buffer); + i64 rhs = device_eval_node(runtime, nodes, indices, node.operand_b, scope, arg_buffer); return lhs > rhs ? lhs : rhs; } case K::kMaxOverRange: { - i64 begin = device_eval_node(nodes, indices, node.operand_a, scope, arg_buffer); - i64 end = device_eval_node(nodes, indices, node.operand_b, scope, arg_buffer); - // Mirror of the host evaluator's iteration guard (see `adstack_size_expr_eval.cpp::evaluate_node`). - // A range of several million would stall the sizer launch for seconds; anything that wide is almost - // certainly a pre-pass bug. Hard-stop via quadrants_assert so the failure surfaces at qd.sync() with - // a clear adstack-sizer attribution rather than a mysterious launch hang. + i64 begin = device_eval_node(runtime, nodes, indices, node.operand_a, scope, arg_buffer); + i64 end = device_eval_node(runtime, nodes, indices, node.operand_b, scope, arg_buffer); + // Iteration guard. Recognized `MaxOverRange` shapes are dispatched in parallel by the max-reducer and substituted + // to a `Const` before the sizer interpreter walks the tree, so the only way to land in this branch with a delta + // above the cap is an out-of-grammar shape. 
Skip the walk and return 0 to keep the single-thread on-device + // dispatch within the driver's TDR window; the host's `evaluate_node` re-runs the same tree synchronously during + // the diagnose path and raises via its `QD_ERROR_IF` then. constexpr i64 kMaxOverRangeIterations = i64{1} << 24; + if (end > begin && end - begin > kMaxOverRangeIterations) { + return 0; + } i64 result = 0; const i32 var = node.var_id; for (i64 i = begin; i < end; ++i) { - if (i - begin > kMaxOverRangeIterations) { - break; // see host evaluator's note; a sibling assertion in the host path will have fired first. - } if (var >= 0 && var < kDeviceBoundVarCap) { scope->values[var] = i; } - i64 v = device_eval_node(nodes, indices, node.body_node_idx, scope, arg_buffer); + i64 v = device_eval_node(runtime, nodes, indices, node.body_node_idx, scope, arg_buffer); if (v > result) result = v; } @@ -964,17 +985,17 @@ i64 device_eval_node(const quadrants::lang::AdStackSizeExprDeviceNode *nodes, return 0; } case K::kExternalTensorRead: { - // `data_ptr_slot = *(void **)(arg_buffer + arg_buffer_offset)`: read the ndarray's data pointer out of the - // kernel arg buffer at the offset the host encoder precomputed via `args_type->get_element_offset`. This - // replaces the host evaluator's `ctx->array_ptrs` map lookup with a straight field read that the device - // can perform without reaching for a std::unordered_map. + // `data_ptr_slot = *(void **)(arg_buffer + arg_buffer_offset)`: read the ndarray's data pointer out of the kernel + // arg buffer at the offset the host encoder precomputed via `args_type->get_element_offset`. This replaces the + // host evaluator's `ctx->array_ptrs` map lookup with a straight field read that the device can perform without + // reaching for a std::unordered_map. auto data_ptr_raw = *reinterpret_cast(arg_buffer + node.arg_buffer_offset); - // Indices encoded as `[idx_a_raw, elem_stride_a]` pairs per axis, matching `kFieldLoad`'s layout. The - // host encoder in `adstack_size_expr_eval.cpp` pre-computes the C-order element strides from the - // launch context's ndarray shape; a 1-D read collapses to `elem_stride = 1` and recovers the original - // stride-1 sum. The multi-axis case is what this fix unblocks: without the per-axis multiply a 2-D - // `a[i, j]` read would land on `a_flat[i + j]` instead of `a_flat[i * shape[1] + j]`, silently - // under-bounding the sizer and tripping `Adstack overflow` at `qd.sync()`. + // Indices encoded as `[idx_a_raw, elem_stride_a]` pairs per axis, matching `kFieldLoad`'s layout. The host + // encoder in `adstack_size_expr_eval.cpp` pre-computes the C-order element strides from the launch context's + // ndarray shape; a 1-D read collapses to `elem_stride = 1` and recovers the original stride-1 sum. The multi-axis + // case is what this fix unblocks: without the per-axis multiply a 2-D `a[i, j]` read would land on + // `a_flat[i + j]` instead of `a_flat[i * shape[1] + j]`, under-bounding the sizer and tripping `Adstack overflow` + // at `qd.sync()`. i64 linear = 0; for (i32 k = 0; k < node.indices_count; ++k) { const i32 raw = indices[node.indices_offset + 2 * k]; @@ -991,13 +1012,37 @@ i64 device_eval_node(const quadrants::lang::AdStackSizeExprDeviceNode *nodes, } return device_load_element(data_ptr_raw, linear, node.prim_dt); } - case K::kFieldLoad: - // The LLVM encoder always host-folds `FieldLoad` leaves (via `SNodeRwAccessorsBank`) before emitting - // device bytecode, so the interpreter never sees `kFieldLoad`. 
It is reserved for the SPIR-V sizer - // shader's PSB read path. Return zero rather than asserting (this runtime-module compiles to LLVM - // bitcode with no host-assert facility) so a mis-emitted tree surfaces downstream as a wrong-`max_size` - // adstack overflow at `qd.sync()` rather than silently UB here. - return 0; + case K::kFieldLoad: { + // Bound-var-indexed `kFieldLoad` body leaf: the encoder stores `arg_buffer_offset = snode_root_id` and + // `const_value = place_byte_offset_in_root` (i.e. the byte offset of the place leaf within its containing snode + // tree). The base pointer is `runtime->roots[snode_root_id]`, which lives on every LLVM backend (CPU host pointer + // / CUDA / AMDGPU device pointer set up at materialization time). The closed-FieldLoad path host-folds at encode + // time and never reaches this arm; a `snode_root_id < 0` here means the bytecode came from a SPIR-V encoder + // (which stores `root_psb + place_byte_offset` directly in `const_value` and leaves `arg_buffer_offset = -1`), + // not the LLVM path; we cannot resolve it here so return 0 (safe over-approximation - a sentinel max forces the + // host to fall back to capped sizer eval downstream). + const i32 snode_root_id = node.arg_buffer_offset; + if (snode_root_id < 0 || runtime == nullptr) { + return 0; + } + const auto root_ptr = reinterpret_cast(runtime->roots[snode_root_id]); + const i64 place_byte_off = node.const_value; + i64 elem_idx = 0; + for (i32 k = 0; k < node.indices_count; ++k) { + const i32 raw = indices[node.indices_offset + 2 * k]; + const i32 elem_stride = indices[node.indices_offset + 2 * k + 1]; + i64 v = 0; + if (raw >= 0) { + v = raw; + } else { + const i32 var = -(raw + 1); + if (var >= 0 && var < kDeviceBoundVarCap) + v = scope->values[var]; + } + elem_idx += v * static_cast(elem_stride); + } + return device_load_element(root_ptr + place_byte_off, elem_idx, node.prim_dt); + } } return 0; } @@ -1169,6 +1214,76 @@ void runtime_eval_static_bound_count(LLVMRuntime *runtime, RuntimeContext *ctx, runtime->adstack_bound_row_capacities[params->task_index] = count; } +// per-launch per-launch parallel-max evaluator over the body of a captured `StaticAdStackMaxReducerSpec`'s +// `MaxOverRange` node. Single-thread serial walk on every backend (CPU host thread, CUDA / AMDGPU single-thread +// JIT-launched device function), mirroring `runtime_eval_static_bound_count`. The body bytecode reuses the existing +// `AdStackSizeExprDeviceNode` POD format already shared between the host encoder and the LLVM device sizer interpreter +// (`device_eval_node`); The recognizer grammar restricts body kinds to `kConst / kBoundVariable / kExternalTensorRead / +// kAdd / kSub / kMul / kMax`, so the recursive walk never recurses through `kMaxOverRange` or `kFieldLoad` and the +// iteration stays linear in the cross-product of every captured axis. +// +// Multi-axis: walks the cross-product of `params->per_axis_length[0..num_axes)` outermost-first. Per-iteration the +// runtime pre-populates `scope.values[per_axis_var_id[a]] = per_axis_begin[a] + axis_idx_a` for every axis, then +// evaluates `device_eval_node(body_root_idx, &scope, ...)` and updates the running max. Result is written to +// `runtime->adstack_max_reducer_outputs[output_slot]`. Caller clears the slot before the dispatch so an empty range +// leaves the sentinel for the host launcher to detect and floor at zero. 
+void runtime_eval_adstack_max_reduce(LLVMRuntime *runtime, RuntimeContext *ctx, Ptr params_blob, Ptr body_bytecode) {
+  using quadrants::lang::AdStackSizeExprDeviceNode;
+  using quadrants::lang::kAdStackMaxReducerMaxAxes;
+  using quadrants::lang::LlvmAdStackMaxReducerDeviceParams;
+
+  const auto *params = reinterpret_cast<const LlvmAdStackMaxReducerDeviceParams *>(params_blob);
+  const auto *nodes = reinterpret_cast<const AdStackSizeExprDeviceNode *>(body_bytecode);
+  const auto *indices = reinterpret_cast<const i32 *>(reinterpret_cast<const char *>(nodes) +
+                                                      sizeof(AdStackSizeExprDeviceNode) * params->body_node_count);
+
+  const char *arg_buffer = ctx->arg_buffer;
+  DeviceEvalScope scope;
+  for (i32 k = 0; k < kDeviceBoundVarCap; ++k) {
+    scope.values[k] = 0;
+  }
+
+  // Sentinel start: INT64_MIN, so the first body value always wins over an empty cross-product. The caller
+  // normalises the empty case (writes 0 / floors at compile-time) when reading the slot back.
+  i64 running_max = static_cast<i64>(0x8000000000000000ull);
+  const i32 root_idx = params->body_root_node_idx;
+  const u32 num_axes = params->num_axes;
+  if (num_axes == 0 || num_axes > (u32)kAdStackMaxReducerMaxAxes) {
+    runtime->adstack_max_reducer_outputs[params->output_slot] = running_max;
+    return;
+  }
+  // Compute the total cross-product length up front; bail on a zero-length axis to keep the linear walk's mod / div
+  // decomposition well-defined.
+  u64 total_length = 1;
+  for (u32 a = 0; a < num_axes; ++a) {
+    if (params->per_axis_length[a] == 0u) {
+      runtime->adstack_max_reducer_outputs[params->output_slot] = running_max;
+      return;
+    }
+    total_length *= (u64)params->per_axis_length[a];
+  }
+  // Walk a single linear counter `i` over `[0, total_length)` and decompose it into per-axis indices outermost-first
+  // (axis 0 is the slowest-varying = highest stride). The host launcher caps `total_length` at u32 max at the
+  // dispatch site, so the linear counter fits in u32 and the per-axis math stays in i64.
+  for (u64 i = 0; i < total_length; ++i) {
+    u64 rem = i;
+    for (u32 a = num_axes; a-- > 0;) {
+      const u32 len_a = params->per_axis_length[a];
+      const u64 idx_a = rem % (u64)len_a;
+      rem = rem / (u64)len_a;
+      const i32 var_id = params->per_axis_var_id[a];
+      if (var_id >= 0 && var_id < kDeviceBoundVarCap) {
+        scope.values[var_id] = params->per_axis_begin[a] + (i64)idx_a;
+      }
+    }
+    i64 v = device_eval_node(runtime, nodes, indices, root_idx, &scope, arg_buffer);
+    if (v > running_max) {
+      running_max = v;
+    }
+  }
+  runtime->adstack_max_reducer_outputs[params->output_slot] = running_max;
+}
+
 void runtime_eval_adstack_size_expr(LLVMRuntime *runtime, RuntimeContext *ctx, Ptr bytecode) {
   // Bytecode layout:
   // [AdStackSizeExprDeviceHeader][stack_headers[n_stacks]][nodes[total_nodes]][indices[total_indices]]. All three
@@ -1217,7 +1332,7 @@
       // No symbolic bound captured (offline-cache-hit with `size_exprs` dropped) - use the compile-time bound.
       max_size = sh.max_size_compile_time > 0 ? sh.max_size_compile_time : 1;
     } else {
-      i64 v = device_eval_node(nodes, indices, sh.root_node_idx, &scope, arg_buffer);
+      i64 v = device_eval_node(runtime, nodes, indices, sh.root_node_idx, &scope, arg_buffer);
       // Floor at 1 to match the host evaluator (`evaluate_adstack_size_expr`); a tree that evaluates to 0 or
       // negative leaves one slot reserved so the heap base address is still valid and any spurious push surfaces as
       // an overflow rather than a zero-slice alias.
Do NOT clamp upward against `max_size_compile_time`: the compile-time seed is a
diff --git a/quadrants/transforms/static_adstack_analysis.h b/quadrants/transforms/static_adstack_analysis.h
index a540c9e7ca..dbb825fbd1 100644
--- a/quadrants/transforms/static_adstack_analysis.h
+++ b/quadrants/transforms/static_adstack_analysis.h
@@ -131,6 +131,45 @@ struct StaticAdStackBoundExpr {
                 loop_iter_size_expr);
 };
 
+// Captured `MaxOverRange` reducible by a dedicated parallel max-reducer dispatch at launch time. The recognized
+// grammar is `MaxOverRange(begin, end, body)`, where `begin` and `end` evaluate to closed-form scalars after
+// recursive substitution of any deeper captured `MaxOverRange`s, and `body` is integer-typed arithmetic (`Const`,
+// `ExternalTensorRead(arg, [BoundVar(this_var)])`, `Add` / `Sub` / `Mul` / `Max` of those). The runtime dispatches
+// one reducer per spec in dependency order (deepest first); the per-launch result is substituted as a `Const` into
+// the SizeExpr tree so the per-thread sizer never walks the iteration domain. Anything outside the grammar is left
+// for the existing capped path (silent truncation today; tracked as future work). A hedged source-level sketch of a
+// kernel shape that produces such a spec follows this hunk.
+struct StaticAdStackMaxReducerSpec {
+  // Index of the alloca within `AdStackSizingAttribs::allocas` (same indexing the per-thread sizer uses).
+  int32_t stack_id{-1};
+  // Index of the OUTERMOST `MaxOverRange` node in this alloca's `size_expr.nodes`. The runtime keys results by
+  // `(task_id_in_kernel, stack_id, mor_node_idx)` and the substitution helper replaces `nodes[mor_node_idx]` with a
+  // `Const` carrying the dispatched reducer's output. When a chain of nested `MaxOverRange`s is captured as a single
+  // multi-axis spec, this is the outermost node (axis 0); the inner nodes collapse into the per-axis arrays below
+  // and are not separately substituted.
+  int32_t mor_node_idx{-1};
+  // Body subtree root (the innermost `MaxOverRange`'s body for multi-axis specs). Walked at launch time to extract
+  // the arg-id paths the reducer reads from. The body may reference any of the `axis_var_ids` below as bound
+  // variables; the encoder remaps each to a dense device-scope slot in `[0, axis_var_ids.size())`.
+  int32_t body_node_idx{-1};
+  // Per-axis iteration ranges and bound-variable ids, ORDERED outermost-first (axis 0 = the spec's outermost
+  // `MaxOverRange`, axis N-1 = the innermost). The dispatch iterates the cross-product of these ranges; each
+  // `[begin, end)` must evaluate closed-form at dispatch time (after recursive substitution of any deeper captured
+  // `MaxOverRange` ancestors). Single-axis specs have one entry per vector.
+  std::vector<int32_t> axis_begin_node_idxs;
+  std::vector<int32_t> axis_end_node_idxs;
+  std::vector<int32_t> axis_var_ids;
+  // Indices into `size_expr.nodes` that are deeper captured `MaxOverRange` specs this one depends on. The runtime
+  // dispatches in topological order so all dependencies have been substituted before this spec's body is read.
+  std::vector<int32_t> dependent_mor_node_idxs;
+  QD_IO_DEF(stack_id,
+            mor_node_idx,
+            body_node_idx,
+            axis_begin_node_idxs,
+            axis_end_node_idxs,
+            axis_var_ids,
+            dependent_mor_node_idxs);
+};
+
 // SNode descriptor info the analysis needs to capture an SNode-backed gate. The resolver returns `std::nullopt` when
 // the leaf / dense pair has no compile-time descriptor available (e.g. on backends that walk the SNode tree at
 // runtime), in which case the analysis rejects the gate and the runtime falls back to worst-case sizing.
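// Illustrative sketch (hedged; not part of the header): a Quadrants-Python kernel shape of the kind the recognizer
// captures as a single-axis `StaticAdStackMaxReducerSpec`. `arr` / `compute` are hypothetical names; the lowering is
// the one described above (the outer parallel-for supplies the bound variable, the data-dependent inner bound
// becomes `MaxOverRange(0, arr.shape[0], ExternalTensorRead(arr, [BoundVar(i)]))`):
//
//   @qd.kernel
//   def compute(arr: qd.types.ndarray(dtype=qd.i32, ndim=1)):
//       for i in range(arr.shape[0]):   # axis 0: begin = 0, end = arr.shape[0], var = i
//           for j in range(arr[i]):     # body leaf: ExternalTensorRead(arr, [BoundVar(i)])
//               ...                     # reverse-mode work that pushes loop-carried state to the adstack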
diff --git a/tests/cpp/codegen/adstack_max_reducer_shader_test.cpp b/tests/cpp/codegen/adstack_max_reducer_shader_test.cpp
new file mode 100644
index 0000000000..2280a22619
--- /dev/null
+++ b/tests/cpp/codegen/adstack_max_reducer_shader_test.cpp
@@ -0,0 +1,70 @@
+// `quadrants/common/logging.h` must come first: it pulls in the fmt headers, which declare `fmt::formatter`, and
+// `rhi/public_device.h` specialises `fmt::formatter` without its own include of fmt. Swapping the include
+// order here produces a cryptic "use of undeclared identifier 'fmt'" in `public_device.h`.
+#include "quadrants/common/logging.h"
+
+#include <cstdio>
+#include <fstream>
+
+#include "gtest/gtest.h"
+#include "quadrants/codegen/spirv/adstack_max_reducer_shader.h"
+#include "quadrants/rhi/public_device.h"
+
+// Builds the adstack max-reducer SPIR-V binary with a synthetic capability set that matches a PSB+Int64-capable
+// device and writes the word stream to a temporary file. The CI does not run `spirv-val` automatically, but dumping
+// the binary makes it trivial to validate / disassemble the output during local debugging:
+//   spirv-val /tmp/adstack_max_reducer.spv
+//   spirv-dis /tmp/adstack_max_reducer.spv | head -200
+namespace quadrants::lang::spirv {
+
+TEST(AdStackMaxReducerShader, DumpBinary) {
+  DeviceCapabilityConfig caps;
+  caps.set(DeviceCapability::spirv_version, 0x10400);
+  caps.set(DeviceCapability::spirv_has_int64, 1);
+  caps.set(DeviceCapability::spirv_has_physical_storage_buffer, 1);
+
+  auto binary = build_adstack_max_reducer_spirv(Arch::vulkan, &caps);
+  ASSERT_FALSE(binary.empty());
+
+  const char *out_path = "/tmp/adstack_max_reducer.spv";
+  std::ofstream f(out_path, std::ios::binary);
+  f.write(reinterpret_cast<const char *>(binary.data()), binary.size() * sizeof(uint32_t));
+  f.close();
+  std::fprintf(stderr, "[adstack_max_reducer_test] wrote %zu words (%zu bytes) to %s\n", binary.size(),
+               binary.size() * sizeof(uint32_t), out_path);
+}
+
+// Pins that the two required capabilities are gated at the top of `build_adstack_max_reducer_spirv`: dropping either
+// PSB or Int64 flips the return to empty so the dispatch site (`GfxRuntime::dispatch_max_reducers` in
+// `runtime/gfx/adstack_max_reducer_launch.cpp`) early-returns an empty result map and the captured `MaxOverRange`
+// falls back through the per-task sizer's existing capped path instead of feeding invalid SPIR-V to a pipeline
+// factory that would assert at create time. PSB is required because every body leaf reads through the ndarray data
+// pointer the kernel arg buffer carries (PSB load); Int64 is required because the body interpreter widens every
+// integer leaf to i64 and the begin / per-axis-begin reassembly arithmetic uses 64-bit operations. The output atomic
+// itself is u32 so no atomic-i64 capability is needed.
+TEST(AdStackMaxReducerShader, GateReturnsEmptyWhenRequiredCapIsMissing) {
+  auto make_caps = []() {
+    DeviceCapabilityConfig caps;
+    caps.set(DeviceCapability::spirv_version, 0x10400);
+    caps.set(DeviceCapability::spirv_has_int64, 1);
+    caps.set(DeviceCapability::spirv_has_physical_storage_buffer, 1);
+    return caps;
+  };
+
+  {
+    auto caps = make_caps();
+    caps.set(DeviceCapability::spirv_has_physical_storage_buffer, 0);
+    EXPECT_TRUE(build_adstack_max_reducer_spirv(Arch::vulkan, &caps).empty());
+  }
+  {
+    auto caps = make_caps();
+    caps.set(DeviceCapability::spirv_has_int64, 0);
+    EXPECT_TRUE(build_adstack_max_reducer_spirv(Arch::vulkan, &caps).empty());
+  }
+  // Sanity: all required caps present still builds a non-empty binary.
+  {
+    auto caps = make_caps();
+    EXPECT_FALSE(build_adstack_max_reducer_spirv(Arch::vulkan, &caps).empty());
+  }
+}
+
+}  // namespace quadrants::lang::spirv
diff --git a/tests/python/test_adstack.py b/tests/python/test_adstack.py
index 35d9a9a33c..42d657a568 100644
--- a/tests/python/test_adstack.py
+++ b/tests/python/test_adstack.py
@@ -11,6 +11,7 @@ import pytest
 
 import quadrants as qd
+from quadrants.lang import impl
 from quadrants.lang.exception import QuadrantsAssertionError
 from quadrants.lang.misc import is_extension_supported
@@ -4710,3 +4711,438 @@ def compute(x: qd.types.NDArray, selector: qd.types.NDArray, out: qd.types.NDArr
             f"gated index {i} (past advisory_total_num_threads={advisory_cap}) gradient diverged: "
             f"got={got[i]} expected={expected[i]}"
         )
+
+
+@pytest.mark.parametrize(
+    "shape, body_kind",
+    # `shape` selects whether the per-task sizer's `1<<24` host-eval cap fires; the smaller shape stays well below
+    # the cap, while the larger one crosses it. `body_kind` selects which body-leaf and combinator mix the recognizer
+    # must accept and the encoder must lower correctly before the device walk. Each `(shape, body_kind)` combination
+    # is designed so the body's max value over the captured ndarray is always `N_X`, keeping the asserted gradient
+    # identical across the matrix.
+    [
+        (256, "extread"),
+        ((1 << 24) + 1, "extread"),
+        ((1 << 24) + 1, "shape_in_body"),
+        ((1 << 24) + 1, "field_in_body"),
+        ((1 << 24) + 1, "arith_combine"),
+    ],
+    ids=[
+        "small_extread",
+        "above_cap_extread",
+        "above_cap_shape_in_body",
+        "above_cap_field_in_body",
+        "above_cap_arith_combine",
+    ],
+)
+@test_utils.test(require=qd.extension.adstack, cfg_optimization=False)
+def test_max_reducer_pins_stride_for_oversized_axis(shape, body_kind):
+    # A reverse-mode kernel with a parallel-for over an arbitrarily large ndarray axis and an inner range-for bound
+    # to a recognizer-accepted trip-count expression sizes its adstack at launch time and computes the right
+    # gradient, without the per-task sizer's `1<<24` cap firing.
+    #
+    # Internal details: the kernel lowers to `MaxOverRange(0, a.shape[0], <body>)` in the per-stack `SizeExpr`.
+    # `recognize_adstack_max_reducer_specs` captures the spec; the launcher dispatches the parallel max-reducer
+    # before the per-task sizer walks the tree; `substitute_precomputed_max_over_range` rewrites the captured
+    # `MaxOverRange` to `Const`. The above-cap variants place the only non-zero cell at `arr_np[-1] = N_X` so
+    # heap-stride correctness depends on the dispatch walking every element of the axis rather than relying on a
+    # partial host-eval walk. The `shape_in_body` / `field_in_body` variants additionally pin that closed leaves
+    # (`ExternalTensorShape`, `FieldLoad`) host-fold to `kConst` at encode time and never reach the device
+    # interpreter; `arith_combine` exercises every binary combinator (`Add`, `Sub`, `Mul`, `Max`) and the `Const`
+    # leaf in a single body expression that algebraically reduces to `a[i_e]`.
+    N_X = 4
+    arr_np = np.zeros(shape, dtype=np.int32)
+    arr_np[-1] = N_X
+    # `qd.ndarray` rather than the numpy passthrough so the underlying device buffer is host-managed by Quadrants;
+    # the numpy passthrough (`kNone` H2D-blit) caps the device-side mirror at backend-specific limits on macOS Metal
+    # for arrays above ~32 MB, which would prevent the dispatch from observing the cell at `arr_np[-1]` in the
+    # above-cap variant.
+    arr = qd.ndarray(qd.i32, shape=(shape,))
+    arr.from_numpy(arr_np)
+
+    x = qd.field(qd.f32, shape=(N_X,), needs_grad=True)
+    loss = qd.field(qd.f32, shape=(), needs_grad=True)
+    # Closed `FieldLoad` leaf for the `field_in_body` variant. Set to zero so the body's max value remains `N_X`
+    # regardless of the body kind, keeping the asserted gradient uniform across the parametrized matrix.
+    gate = qd.field(qd.i32, shape=())
+    gate[None] = 0
+
+    @qd.kernel
+    def compute(a: qd.types.ndarray(dtype=qd.i32, ndim=1)):
+        for i_e in range(a.shape[0]):
+            # `qd.static(...)` selects the body shape at kernel compile time so each parametrization compiles a
+            # single-branch kernel; every form has algebraic max value `a[i_e]`. The `arith_combine` form exercises
+            # `Add` / `Sub` / `Mul` / `Max` / `Const` together: outer `Max` of the two equal sub-expressions
+            # `a[i_e] + 0` (`Add` + `Const`) and `a[i_e] * 1 - 0` (`Mul` + `Sub` + `Const`).
+            n = (
+                a[i_e]
+                if qd.static(body_kind == "extread")
+                else (
+                    a[i_e] + (a.shape[0] - a.shape[0])
+                    if qd.static(body_kind == "shape_in_body")
+                    else (
+                        max(a[i_e], gate[None])
+                        if qd.static(body_kind == "field_in_body")
+                        else max(a[i_e] + 0, a[i_e] * 1 - 0)
+                    )
+                )
+            )
+            accum = 0.0
+            for j in range(n):
+                accum = accum + x[j] * x[j]
+            loss[None] += accum
+
+    for i in range(N_X):
+        x[i] = 0.1
+
+    prog = impl.get_runtime().prog
+    prog._reset_max_reducer_dispatch_count()
+
+    compute(arr)
+    loss.grad[None] = 1.0
+    for i in range(N_X):
+        x.grad[i] = 0.0
+    compute.grad(arr)
+    qd.sync()
+
+    # Only the last outer iteration walks the inner loop; every other iteration contributes nothing. The max-reducer
+    # dispatch covers every element of `arr` so the heap stride lands at the actual maximum (= N_X), and
+    # `compute.grad(arr)` plus `qd.sync()` runs to completion. The expected per-slot gradient is `2 * x[k]` since
+    # each surviving inner iteration contributes `2 * x[k]` to the reverse pass.
+    assert prog._get_max_reducer_dispatch_count() >= 1
+    for k in range(N_X):
+        assert x.grad[k] == pytest.approx(2 * 0.1, rel=1e-5)
+
+
+@test_utils.test(require=qd.extension.adstack, cfg_optimization=False)
+def test_max_reducer_dispatch_counts_advance_on_input_mutation():
+    # Pins the dispatch + cache invalidation pipeline. The first launch must fire at least one max-reducer dispatch
+    # (the kernel's `MaxOverRange(0, a.shape[0], a[var])` matches the recognizer grammar so the recognizer captures
+    # the spec; the launcher dispatches once and bumps `Program.max_reducer_dispatch_count`). A subsequent host
+    # mutation of the gating ndarray must bump `ndarray_data_gen` and force the next launch to re-dispatch, advancing
+    # the counter beyond its post-first-launch value. Steady-state cache short-circuiting on an unchanged ndarray is
+    # backend-dependent (the CPU launcher's `set_host_accessible_ndarray_ptrs` path converts qd.ndarray reads to
+    # `kNone` semantics and `bump_writes_for_kernel_llvm` then bumps the gen on every read; the SPIR-V launchers
+    # preserve the qd.ndarray dev-alloc-type and only bump on writes), so this test asserts only the
+    # mutation-triggers-redispatch contract that holds uniformly.
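+    #
+    # A condensed usage sketch of the contract pinned here (comment only; `compute` / `a` are the names defined
+    # below, `new_vals` is illustrative):
+    #   prog._reset_max_reducer_dispatch_count()
+    #   compute(a); compute.grad(a)   # first launch -> recognizer captures, reducer dispatches, counter >= 1
+    #   a.from_numpy(new_vals)        # host mutation bumps `ndarray_data_gen`
+    #   compute(a); compute.grad(a)   # cache replay invalidates -> counter advances past its pre-mutation value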
+    N = 4
+
+    x = qd.field(qd.f32, shape=(N,), needs_grad=True)
+    y = qd.field(qd.f32, shape=(), needs_grad=True)
+
+    @qd.kernel
+    def compute(a: qd.types.ndarray(dtype=qd.i32, ndim=1)):
+        for i in range(a.shape[0]):
+            v = x[i]
+            n = a[i]
+            for _ in range(n):
+                v = v * 0.95 + 0.01
+            y[None] += v
+
+    a = qd.ndarray(qd.i32, shape=(N,))
+    a.from_numpy(np.array([2, 3, 1, 2], dtype=np.int32))
+    for i in range(N):
+        x[i] = 0.1
+
+    prog = impl.get_runtime().prog
+    prog._reset_max_reducer_dispatch_count()
+
+    compute(a)
+    y.grad[None] = 1.0
+    for i in range(N):
+        x.grad[i] = 0.0
+    compute.grad(a)
+    qd.sync()
+    after_first = prog._get_max_reducer_dispatch_count()
+    assert after_first >= 1
+
+    a.from_numpy(np.array([3, 3, 1, 2], dtype=np.int32))
+    pre_mutation = prog._get_max_reducer_dispatch_count()
+    compute(a)
+    y.grad[None] = 1.0
+    for i in range(N):
+        x.grad[i] = 0.0
+    compute.grad(a)
+    qd.sync()
+    assert prog._get_max_reducer_dispatch_count() > pre_mutation
+
+
+@test_utils.test(require=qd.extension.adstack, cfg_optimization=False)
+def test_max_reducer_grammar_fallback():
+    # Pins the recognizer's grammar gate. A reverse-mode kernel whose inner trip count is a compile-time constant
+    # (no `MaxOverRange` wrapper in the resulting `SizeExpr`) does not match the recognizer grammar, so there is no
+    # spec for `recognize_adstack_max_reducer_specs` to capture. The launcher's pre-publish dispatch finds an empty
+    # `max_reducer_specs` list, fires no max-reducer dispatch, and the per-task sizer's existing host / device
+    # evaluator handles the constant trip count via its `Const` leaf path. The dispatch counter must stay at zero
+    # and the analytical gradient must still match. Pins the "any kernel outside the captured grammar runs
+    # unchanged" contract so future grammar broadening cannot silently drop the fallback path.
+    N = 4
+    K = 3
+
+    x = qd.field(qd.f32, shape=(N,), needs_grad=True)
+    y = qd.field(qd.f32, shape=(), needs_grad=True)
+
+    @qd.kernel
+    def compute():
+        for i in range(N):
+            v = x[i]
+            for _ in range(K):
+                v = v * 0.95 + 0.01
+            y[None] += v
+
+    for i in range(N):
+        x[i] = 0.1
+
+    prog = impl.get_runtime().prog
+    prog._reset_max_reducer_dispatch_count()
+
+    compute()
+    y.grad[None] = 1.0
+    for i in range(N):
+        x.grad[i] = 0.0
+    compute.grad()
+    qd.sync()
+
+    assert prog._get_max_reducer_dispatch_count() == 0
+    expected = 0.95**K
+    for i in range(N):
+        assert x.grad[i] == pytest.approx(expected, rel=1e-5)
+
+
+@pytest.mark.parametrize(
+    "body_kind",
+    [
+        "field_bv",
+        "field_bv_plus_arr_bv",
+        "arr_bv_plus_field_bv",
+        "max_field_bv_arr_bv",
+        "max_field_bv_const",
+        "field_bv_arith_combine",
+        "field_bv_indexed_by_field_load",
+        "arr_bv_indexed_by_field_load",
+    ],
+)
+@test_utils.test(require=qd.extension.adstack)
+def test_max_reducer_field_load_bound_var_dispatch(body_kind):
+    # A reverse-mode kernel whose inner range-for trip count reads a `qd.field` indexed by the outer chain variable
+    # captures via the parallel max-reducer dispatch and produces the analytical gradient. The body-shape
+    # parametrization exercises every supported composition: bound-var FieldLoad on its own, mixed with bound-var
+    # ETR via `Add` / `Max`, combined with `Const` / arithmetic, and the nested-load worst-case form
+    # (`field[field[i]]` / `arr[field[i]]`).
+    #
+    # Internal details: each variant lowers to `MaxOverRange(0, M, body)` where `body` is a bound-var-indexed
+    # `FieldLoad(field_a, [bound_var])` or a recognizer-accepted composition that includes one.
+    # The relaxed `max_reducer_body_is_recognizable::FieldLoad` arm accepts the leaf, the encoder emits a
+    # `kFieldLoad` device node whose base pointer is pre-resolved on host (PSB on SPIR-V,
+    # `runtime->roots[id] + place_byte_offset` on LLVM), and the dispatch reads `field_a[i]` for every `i` and keeps
+    # the max. The two `_indexed_by_field_load` variants exercise the conservative-wrapper path:
+    # `SerializedSizeExprNode::indices` carries one int32 per axis (no subtree refs), so the trip-count builder
+    # substitutes `MaxOverRange(var, 0, leaf_snode.shape, body=Load(snode, [var]))`, which iterates the leaf snode's
+    # full axis - the recognizer captures it via the same bound-var route and the dispatched max equals
+    # `max_k field_a[k]` (resp. `max_k arr[k]`). Across all variants the body's max value over the indexed range is
+    # `N_X`, keeping the asserted gradient identical.
+    N_X = 4
+    M = 8
+    # Field-a holds the bound-var-indexed counter values: the peak value `N_X` lands at the last cell, so a
+    # per-element walk is necessary to verify heap-stride correctness; a partial walk that stops at the first
+    # non-zero cell would under-bound the heap stride.
+    field_a = qd.field(qd.i32, shape=(M,))
+    field_a_init = np.zeros(M, dtype=np.int32)
+    field_a_init[-1] = N_X
+    for i in range(M):
+        field_a[i] = int(field_a_init[i])
+    # Field-b is the inner-index source for the `_indexed_by_field_load` variants. Setting every cell to the index
+    # of field_a's peak (M-1) routes every outer iteration to the cell holding `N_X`; the dispatch's worst-case
+    # wrapper walks field_a's full axis regardless, so the max reduction still observes `N_X` and the gradient stays
+    # uniform.
+    field_b = qd.field(qd.i32, shape=(M,))
+    for i in range(M):
+        field_b[i] = M - 1
+    arr = qd.ndarray(qd.i32, shape=(M,))
+    arr.from_numpy(field_a_init)
+
+    x = qd.field(qd.f32, shape=(N_X,), needs_grad=True)
+    loss = qd.field(qd.f32, shape=(), needs_grad=True)
+
+    @qd.kernel
+    def compute(a: qd.types.ndarray(dtype=qd.i32, ndim=1)):
+        for i_e in range(M):
+            # Each variant is an algebraic identity over the value at `field_a[i_e]` (or `field_a[field_b[i_e]]` for
+            # the nested-load forms): the max value over the captured axis is `N_X` so the asserted gradient stays
+            # uniform.
+            n = (
+                field_a[i_e]
+                if qd.static(body_kind == "field_bv")
+                else (
+                    field_a[i_e] + (a[i_e] - a[i_e])
+                    if qd.static(body_kind == "field_bv_plus_arr_bv")
+                    else (
+                        a[i_e] + (field_a[i_e] - a[i_e])
+                        if qd.static(body_kind == "arr_bv_plus_field_bv")
+                        else (
+                            max(field_a[i_e], a[i_e])
+                            if qd.static(body_kind == "max_field_bv_arr_bv")
+                            else (
+                                max(field_a[i_e], 0)
+                                if qd.static(body_kind == "max_field_bv_const")
+                                else (
+                                    max(field_a[i_e] + 0, field_a[i_e] * 1 - 0)
+                                    if qd.static(body_kind == "field_bv_arith_combine")
+                                    else (
+                                        field_a[field_b[i_e]]
+                                        if qd.static(body_kind == "field_bv_indexed_by_field_load")
+                                        else a[field_b[i_e]]
+                                    )
+                                )
+                            )
+                        )
+                    )
+                )
+            )
+            accum = 0.0
+            for j in range(n):
+                accum = accum + x[j] * x[j]
+            loss[None] += accum
+
+    for i in range(N_X):
+        x[i] = 0.1
+
+    prog = impl.get_runtime().prog
+    prog._reset_max_reducer_dispatch_count()
+
+    compute(arr)
+    loss.grad[None] = 1.0
+    for i in range(N_X):
+        x.grad[i] = 0.0
+    compute.grad(arr)
+    qd.sync()
+
+    # Only one outer iteration walks the inner loop with a non-zero count (the cell at position `M-1` for the direct
+    # variants, or every iteration via field_b -> field_a[M-1] for the nested variants); each surviving inner
+    # iteration contributes `2 * x[k]` to `x.grad[k]`.
+    # The recognizer captures every variant via the bound-var FieldLoad / ETR path, so the dispatch counter must
+    # advance.
+    assert prog._get_max_reducer_dispatch_count() >= 1
+    if body_kind in ("field_bv_indexed_by_field_load", "arr_bv_indexed_by_field_load"):
+        # Nested-load worst case: every outer iteration routes to the peak cell so the reverse pass accumulates `M`
+        # times.
+        expected = 2 * 0.1 * M
+    else:
+        expected = 2 * 0.1
+    for k in range(N_X):
+        assert x.grad[k] == pytest.approx(expected, rel=1e-5)
+
+
+@test_utils.test(require=qd.extension.adstack)
+def test_max_reducer_field_load_bound_var_cache_invalidates_on_snode_mutation():
+    # A reverse-mode kernel whose inner trip count reads a `qd.field` indexed by the outer chain variable
+    # redispatches the max-reducer when the gating field is mutated between launches.
+    #
+    # Internal details: the encoder emits a `kFieldLoad` device node and pushes a `FieldLoadObs` carrying the snode
+    # id and the live `snode_write_gen` snapshot. On the second launch's `try_max_reducer_cache_hit`,
+    # `replay_one_observation`'s `FieldLoadObs` arm fast-skips on a matching gen and otherwise falls through to the
+    # invalidate path (`obs.indices == {}` means the gen counter is the sole staleness signal for max-reducer body
+    # observations). Mutating `field_a[M-1]` from Python bumps `snode_write_gen` so the second launch's replay
+    # invalidates the entry and the dispatch counter advances beyond `after_first`.
+    M = 8
+    N_X = 4
+
+    field_a = qd.field(qd.i32, shape=(M,))
+    for i in range(M):
+        field_a[i] = 0
+    field_a[M - 1] = 2
+
+    x = qd.field(qd.f32, shape=(N_X,), needs_grad=True)
+    loss = qd.field(qd.f32, shape=(), needs_grad=True)
+
+    @qd.kernel
+    def compute():
+        for i_e in range(M):
+            n = field_a[i_e]
+            accum = 0.0
+            for j in range(n):
+                accum = accum + x[j] * x[j]
+            loss[None] += accum
+
+    for i in range(N_X):
+        x[i] = 0.1
+
+    prog = impl.get_runtime().prog
+    prog._reset_max_reducer_dispatch_count()
+
+    compute()
+    loss.grad[None] = 1.0
+    for i in range(N_X):
+        x.grad[i] = 0.0
+    compute.grad()
+    qd.sync()
+    after_first = prog._get_max_reducer_dispatch_count()
+    assert after_first >= 1
+
+    # Bump field_a's peak value to force a different max; the snode write must bump `snode_write_gen` and the next
+    # launch's cache replay must invalidate, advancing the dispatch counter.
+    field_a[M - 1] = 4
+    pre_mutation = prog._get_max_reducer_dispatch_count()
+    compute()
+    loss.grad[None] = 1.0
+    for i in range(N_X):
+        x.grad[i] = 0.0
+    compute.grad()
+    qd.sync()
+    assert prog._get_max_reducer_dispatch_count() > pre_mutation
+
+
+@test_utils.test(require=qd.extension.adstack, cfg_optimization=False)
+def test_above_cap_out_of_grammar_kernel_raises():
+    # A reverse-mode kernel whose inner `range(...)` trip count is bound to an out-of-grammar `MaxOverRange` body
+    # and whose iteration count exceeds the `1<<24` adstack-sizer cap surfaces a `QuadrantsAssertionError` at
+    # `qd.sync()` (or a `RuntimeError` directly from the launch on the CPU host-eval path; see the match-set note
+    # below).
+    #
+    # Internal details: the recognizer's body grammar accepts only `Const / ExternalTensorRead / Add / Sub / Mul /
+    # Max / ExternalTensorShape / FieldLoad(literal-or-bound-var indices)`, and `max_reducer_body_is_recognizable`
+    # further restricts `ExternalTensorRead` leaves to dtypes whose value range cannot collide with the
+    # cache-revalidation sentinel (`INT64_MIN`) - `i8 / i16 / i32 / u8 / u16 / u32` only.
+    # An `i64` ndarray read passes the host evaluator (`evaluate_node`'s `ExternalTensorRead` arm reads any integer
+    # dtype) but fails the recognizer's dtype check, so the whole spec is dropped and the per-task sizer walks the
+    # outer `MaxOverRange` itself. With `a.shape[0] > 1<<24` the cap fires on the host evaluator (`QD_ERROR_IF` in
+    # `adstack_size_expr_eval.cpp::evaluate_node`, raised as `RuntimeError` on the CPU host fast path) and on the
+    # SPIR-V on-device sizer (the trailing overflow-flag slot of the metadata buffer, raised as
+    # `QuadrantsAssertionError` from the host post-readback in `publish_adstack_metadata_spirv`). The CUDA and
+    # AMDGPU LLVM-GPU sizers short-circuit the walk and return 0 from `device_eval_node`'s `kMaxOverRange` arm so
+    # the single-thread on-device dispatch stays within the driver's TDR window; the cap-hit then surfaces
+    # indirectly via the existing `stack_push` overflow infrastructure on a subsequent main-kernel launch, and the
+    # resulting diagnostic message attribution depends on the kernel layout. That indirect path is covered by
+    # `test_adstack_overflow_diagnostic_and_auto_recovery`.
+    N_X = 4
+    shape = (1 << 24) + 1
+    # An all-zero gating ndarray keeps the forward kernel's actual inner-loop work at zero on every thread; the
+    # cap-hit is purely a property of the symbolic `MaxOverRange` iteration count, so we do not need any cell to be
+    # non-zero for the per-task sizer's walk to overflow the guard.
+    a_data = np.zeros(shape, dtype=np.int64)
+    a = qd.ndarray(qd.i64, shape=(shape,))
+    a.from_numpy(a_data)
+
+    x = qd.field(qd.f32, shape=(N_X,), needs_grad=True)
+    loss = qd.field(qd.f32, shape=(), needs_grad=True)
+
+    @qd.kernel
+    def compute(a: qd.types.ndarray(dtype=qd.i64, ndim=1)):
+        for i_e in range(a.shape[0]):
+            # `a` is an `i64` ndarray, so the inner `MaxOverRange`'s `end` is an `ExternalTensorRead` with leaf
+            # dtype `i64`. `max_reducer_body_is_recognizable` rejects `i64 / u64` leaves (the cache-revalidation
+            # sentinel `INT64_MIN` is a legal value of an `i64` cell, so a mutated cache entry could false-hit on
+            # revalidation). The whole spec is dropped and the per-task sizer walks the outer
+            # `MaxOverRange(0, shape[0], ...)` itself, hits the `1<<24` cap, and raises on every backend.
+            accum = 0.0
+            for j_e in range(a[i_e]):
+                n = a[j_e]
+                for k in range(n):
+                    accum = accum + x[k] * x[k]
+            loss[None] += accum
+
+    for i in range(N_X):
+        x[i] = 0.1
+
+    # The host evaluator on CPU raises `RuntimeError` directly from `prog.launch_kernel` (the `QD_ERROR_IF` path
+    # surfaces as `RuntimeError` to Python); the device sizers raise `QuadrantsAssertionError` from `qd.sync()` once
+    # the overflow flag is polled. The match set covers both raise paths uniformly.
+    with pytest.raises((QuadrantsAssertionError, RuntimeError)):
+        compute(a)
+        loss.grad[None] = 1.0
+        for i in range(N_X):
+            x.grad[i] = 0.0
+        compute.grad(a)
+        qd.sync()
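+
+
+# Not a test: a hedged sketch of the rewrite that brings the kernel above back inside the recognizer grammar. The
+# dtype list comes from `max_reducer_body_is_recognizable` (i8 / i16 / i32 / u8 / u16 / u32); `a32` / `compute32`
+# are hypothetical names, and the values must actually fit in 32 bits for the cast to be sound:
+#
+#   a32 = qd.ndarray(qd.i32, shape=(shape,))
+#   a32.from_numpy(a_data.astype(np.int32))
+#
+#   @qd.kernel
+#   def compute32(a: qd.types.ndarray(dtype=qd.i32, ndim=1)):
+#       for i_e in range(a.shape[0]):
+#           for j_e in range(a[i_e]):   # i32 ExternalTensorRead leaf -> captured by the max-reducer dispatch
+#               ...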