Skip to content

fix: fix rocm greedy sampling to avoid crash#932

Open
liaocz wants to merge 1 commit intomainfrom
fix/rocm_greedy_sampling
Open

fix: fix rocm greedy sampling to avoid crash#932
liaocz wants to merge 1 commit intomainfrom
fix/rocm_greedy_sampling

Conversation

@liaocz
Copy link
Copy Markdown
Collaborator

@liaocz liaocz commented Apr 24, 2026

Motivation

Fix ROCm greedy sampling coredump when requests include top_p parameter. The ROCm sampling path was not aligned with the CUDA path, causing all token probabilities to be masked to zero during top_p filtering, which leads to a crash in torch::multinomial.

Key Changes

1. Add do_sample Flag Check — Align with CUDA Path

  • Temperature penalty and repetition/presence/frequency penalty are now only applied when need_do_sample is true
  • Added has_not_do_sample and need_do_sample boolean checks based on params.do_sample

2. Fix Inplace Modification in Top-K Filtering

  • Changed auto filtered_probs = probs_t to auto filtered_probs = probs_t.clone() to avoid corrupting the original probability tensor, which is needed later for cum_log_probs computation

3. Fix Top-P Filtering Coredump (Root Cause)

  • Explicitly pass dim=-1 to row.sort() to ensure correct sorting behavior
  • Added mask[0] = false to always preserve the highest-probability token, preventing all tokens from being masked to zero — this is the direct cause of the coredump

4. Fix cum_log_probs Computation

Aspect Before After
Logic cum_log_probs_t.add_(probs_t.log()) — applies log over entire distribution cum_log_probs_t.add_(token_probs.log()) — gathers only the sampled token's probability via gather, then applies log
Device safety No device alignment Added .to(cum_log_probs_t.device())

@liaocz liaocz requested a review from LLLLKKKK as a code owner April 24, 2026 07:59
@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review — PR #932

PR 概述

Title: fix: fix rocm greedy sampling to avoid coredump
Author: liaocz
规模: 1(GitHub) files, +22/-10

核心目标

修复 ROCm 平台 legacy sampleGreedy 路径中的多个正确性问题:(1) 缺少 do_sample 标志处理导致不需要采样的请求也被施加 penalty;(2) filtered_probs 未 clone 导致原地修改污染原始概率;(3) sort() API 参数错误;(4) top_p 过滤可能清零所有 token;(5) cum_log_probs 计算逻辑错误。


改动逻辑拆解

GitHub 开源仓库变更(主要代码)

1. do_sample 标志处理(核心逻辑)

新增 has_not_do_sampleneed_do_sample 两个布尔变量,与 CUDA 路径对齐。temperature penalty 和 repetition/presence/frequency penalty 现在仅在 need_do_sample 为 true 时才执行。

2. filtered_probs clone 修复

filtered_probs = probs_t 改为 filtered_probs = probs_t.clone()。原代码中 filtered_probsprobs_t 的别名,top_k/top_p 过滤会原地修改 probs_t,导致后续 cum_log_probs 计算使用被污染的概率值。

3. sort() API 参数修正

row.sort(/*descending=*/true) 改为 row.sort(/*dim=*/-1, /*descending=*/true)。PyTorch C++ sort() 签名为 sort(int64_t dim, bool descending),旧代码将 true(即 1)作为 dim 参数传入,对 1D tensor 虽然不会 crash,但语义不正确。

4. top_p 过滤保留 top-1 token

新增 mask[0] = false,确保 top_p 过滤后至少保留概率最高的 token。原代码在极端 top_p 值下可能将所有 token 概率清零,导致 multinomial 采样 crash(coredump 的直接原因之一)。

5. cum_log_probs 计算修正

原代码 cum_log_probs_t.add_(probs_t.log()) 将整个概率分布的 log 加到 cum_log_probs 上,这是错误的——应该只加被采样 token 的 log 概率。修正为 probs_t.gather(1, samples_t.unsqueeze(1).to(torch::kLong)).squeeze(1) 后取 log。


Checklist 检查结果

通用原则

软件工程原则

检查项 结果
SRP ✅ 单一职责:修复 ROCm 采样路径
OCP
LSP
ISP
DIP
DRY ❌ legacy 路径与 CUDA 路径大量重复逻辑,但这是已有技术债,非本 PR 引入
KISS
YAGNI

架构审视

检查项 结果
抽象边界
依赖方向
状态完整性
错误语义
可观测性
可演进性 ❌ 见 P1-1:新设备抽象层 ROCmSampleOp.cc 未同步修复
可运维性

测试

检查项 结果
新功能有对应测试 ❌ 见 P1-2:无 do_sample 混合 batch 测试
删除的测试有等价替代 ✅ 无删除
边界 case 覆盖 ❌ 缺少 top_p 极端值、全 do_sample=false 等边界测试
分布式改动有多卡测试 ✅ 不涉及

代码质量与文档

检查项 结果
无关改动分离
mega-PR 拆分 ✅ 小 PR
Commit 原子性 ✅ 单 commit 自包含
Commit message 准确性
PR description 说明动机和设计 ❌ PR body 为空,未说明 coredump 的触发条件和修复思路
日志频率控制

领域检查

A. 兼容性与配置 — 全部 ✅

B. 正确性与逻辑

检查项 结果
逻辑错误、off-by-one、null/zero 检查 ✅ 修复正确
边界 case ❌ 见 P2-1:混合 batch 下 do_sample=false 的 item 仍被施加 penalty
死代码 / 不可达分支
共享对象原地修改需浅拷贝(1.22) ✅ clone 修复正确
过滤后使用过滤结果而非原始集合(1.46)

C. 线程安全与并发 — 全部 ✅

D. 性能 — 全部 ✅(clone 引入一次额外拷贝,但仅在非 top_k=1 路径,可接受)

E. 分布式 — 全部 ✅

F. 跨平台(ROCm/ARM)

检查项 结果
CUDA/ROCm 两侧 binding 对称清理(1.19) ❌ 见 P1-1:新 ROCmSampleOp.cc 未同步修复
ROCm 路径错误处理非静默(1.31)

G. 语言与框架特有 — 全部 ✅

H. 测试与 CI

检查项 结果
测试覆盖充分(1.5) ❌ 见 P1-2

I. 代码质量 — 全部 ✅


Review 意见

问题

  1. 新设备抽象层 ROCmSampleOp.cc 存在相同 bug 未修复 [P1]

    rtp_llm/cpp/devices/rocm_impl/ROCmSampleOp.cc 第 206-208 行的 cum_log_probs 计算与本 PR 修复前的 legacy 代码完全相同:

    Buffer2torchTensor(*params.cum_log_probs).add_(probs_t.log());

    这会将整个概率分布的 log 加到 cum_log_probs 上,而非仅加被采样 token 的 log 概率。

    同时该文件也完全缺少 do_sample 标志处理——temperature 和 repetition penalty 无条件施加给所有 batch item。

    建议:将本 PR 的 5 项修复同步到 ROCmSampleOp.cc,或确认该文件是否已废弃不再使用。

  2. 缺少测试覆盖 [P1]

    本 PR 修复了 5 个独立的正确性问题,但未新增任何测试用例。ROCmSamplerTest.cc 中现有测试未覆盖:

    • do_sample 混合 batch(部分 true 部分 false)
    • 极端 top_p 值(接近 0)触发全零概率
    • cum_log_probs 正确性验证

    建议:至少补充一个 do_sample 混合 batch 的测试和一个 cum_log_probs 正确性测试。

  3. 混合 batch 下 do_sample=false 的 item 仍被施加 penalty [P2]

    CUDA 路径在 need_do_sample 为 true 时,会先保存 do_sample=false item 的 logits,执行 penalty 后再恢复:

    if (has_not_do_sample) {
        selected_logits = logits_tensor.masked_select(mask_tensor);
    }
    processLogits(params, ...);
    if (has_not_do_sample) {
        logits_tensor.masked_scatter_(mask_tensor, selected_logits);
    }

    本 PR 的 ROCm 路径虽然计算了 has_not_do_sample,但未使用它做 logits 保存/恢复。在混合 batch 场景下,do_sample=false 的 item 的 logits 仍会被 temperature/repetition penalty 修改。

    建议:补充 logits 保存/恢复逻辑,与 CUDA 路径完全对齐。

  4. PR description 为空 [P2]

    PR body 为空,未说明 coredump 的具体触发条件(哪个场景、什么参数组合)、根因分析和修复思路。对于修复 5 个独立问题的 PR,缺少描述会增加 reviewer 理解成本。

    建议:补充 PR description,说明 coredump 的触发路径和每项修复的动机。

小问题

  • 注释编号不连续:步骤注释从 "0. Check do_sample" 到 "6. Update cum_log_probs" 再到 "8. Copy results back",跳过了 7。建议统一编号。

整体评价

本 PR 修复了 ROCm legacy 采样路径中的多个真实正确性问题,每项修复的方向都是正确的。probs_t.clone()sort() 参数修正、mask[0] = falsecum_log_probs gather 修复都是必要的。主要风险在于:(1) 新设备抽象层 ROCmSampleOp.cc 存在完全相同的 bug 未被修复;(2) 混合 batch 的 logits 保存/恢复逻辑缺失;(3) 无测试覆盖。

存在阻塞/重要问题,不建议合入(P1 x 2)

@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review — PR #932

PR 概述

Title: fix: fix rocm greedy sampling to avoid coredump
Author: liaocz
规模: 1(GitHub) files, +22/-10

核心目标

修复 ROCm 平台 greedy sampling 路径的多个正确性问题,包括:缺少 do_sample 标志检查导致不必要的 penalty 应用、top_p 过滤可能清零所有 token 概率导致 coredump、cum_log_probs 计算错误、以及 filtered_probs 未 clone 导致原始概率被污染。


改动逻辑拆解

GitHub 开源仓库变更(主要代码)

1. ROCm sampleGreedy — do_sample 标志检查(对齐 CUDA 路径)

新增 has_not_do_sampleneed_do_sample 两个布尔变量,与 CUDA 路径的 sampleGreedy 保持一致。temperature penalty 和 repetition penalty 均增加 need_do_sample 前置条件,避免对不需要采样的请求施加 penalty。

2. ROCm sampleGreedy — filtered_probs clone

auto filtered_probs = probs_t;(浅引用)改为 auto filtered_probs = probs_t.clone();(深拷贝)。修复 top_k/top_p 过滤操作会原地修改 probs_t 的问题,probs_t 在后续 cum_log_probs 计算中仍需使用原始概率值。

3. ROCm sampleGreedy — top_p 过滤安全性

  • row.sort(...) 增加显式 dim=-1 参数,避免依赖默认行为。
  • 新增 mask[0] = false; 确保 top-1 token 始终保留。当 top_p 值极小时,cumsum 可能导致所有 token 被 mask 掉,概率全为 0,torch::multinomial 会触发非法内存访问(coredump)。

4. ROCm sampleGreedy — cum_log_probs 修复

旧代码 cum_log_probs_t.add_(probs_t.log()) 对整个概率矩阵取 log 后加到 cum_log_probs(shape 不匹配,行为未定义)。修复为 probs_t.gather(1, samples_t.unsqueeze(1).to(torch::kLong)).squeeze(1) 先提取被选中 token 的概率,再取 log 累加。


Checklist 检查结果

通用原则

软件工程原则

检查项 结果
SRP ✅ 单一职责:修复 ROCm sampling 正确性
OCP ✅ 不涉及
LSP ✅ 不涉及
ISP ✅ 不涉及
DIP ✅ 不涉及
DRY ❌ do_sample 检查逻辑与 CUDA 路径完全重复(见 P2-1)
KISS ✅ 修复逻辑直接明了
YAGNI ✅ 无过度设计

架构审视

检查项 结果
抽象边界
依赖方向
状态完整性
错误语义
可观测性
可演进性 ❌ CUDA 和 ROCm 两套 sampleGreedy 实现长期分叉,维护成本高(见 P2-2)
可运维性

测试

检查项 结果
新功能有对应测试 ❌ 无新增测试覆盖修复场景(见 P1-1)
删除的测试有等价替代 ✅ 不涉及
边界 case 覆盖 ❌ 缺少 top_p 极小值、mixed do_sample batch 的边界测试
分布式改动有多卡测试 ✅ 不涉及

代码质量与文档

检查项 结果
无关改动分离
mega-PR 拆分 ✅ 单文件小改动
Commit 原子性 ✅ 单 commit 自包含
Commit message 准确性 ✅ 准确描述修复意图
PR description 说明动机和设计 ❌ PR body 为空,未说明 coredump 的触发条件和修复思路(见 P2-3)
日志频率控制 ✅ 不涉及

领域检查

A. 兼容性与配置 — 全部 ✅

B. 正确性与逻辑

检查项 结果
逻辑错误 ❌ mixed batch 下 do_sample 处理不完整(见 P1-2)
边界 case ✅ top_p=0 场景已通过 mask[0]=false 修复
共享对象原地修改(1.22) ✅ clone 修复了 probs_t 原地修改问题

C. 线程安全与并发 — 全部 ✅

D. 性能 — 全部 ✅

E. 分布式 — 全部 ✅

F. 跨平台(ROCm/ARM)

检查项 结果
CUDA/ROCm 两侧 binding 对称清理(1.19) ❌ ROCm 路径 do_sample 处理与 CUDA 路径行为不一致(见 P1-2)
ROCm 路径错误处理非静默(1.31)

G. 语言与框架特有 — 全部 ✅

H. 测试与 CI — ❌ 见 P1-1

I. 代码质量 — 全部 ✅


Review 意见

问题

  1. ROCm 路径 mixed batch do_sample 处理与 CUDA 路径行为不一致 [P1]

    CUDA 路径在 mixed batch(部分请求 do_sample=true,部分 do_sample=false)时,会先保存非采样请求的原始 logits,对全 batch 执行 processLogits(temperature + repetition penalty),然后用 masked_scatter_ 恢复非采样请求的 logits。这确保了非采样请求不受 penalty 影响。

    ROCm 路径仅用 need_do_sample 做整体开关:只要 batch 中有任何一个请求需要采样,就对全 batch 施加 temperature 和 repetition penalty。这意味着 do_sample=false 的请求也会被错误地施加 penalty,导致 argmax 结果偏移。

    // CUDA 路径(正确):
    if (need_do_sample) {
        if (has_not_do_sample) {
            selected_logits = params.logits.masked_select(mask_tensor);  // 保存
        }
        processLogits(params, device_tokens, transposed_tokens);
        if (has_not_do_sample) {
            params.logits.masked_scatter_(mask_tensor, selected_logits);  // 恢复
        }
    }
    
    // ROCm 路径(当前):直接对全 batch 施加 penalty,无 mask 保护
    if (need_do_sample && ...) {
        invokeBatchApplyTemperaturePenalty(...);  // 影响所有 batch item
    }

    建议:在 ROCm 路径中增加与 CUDA 路径相同的 masked_select / masked_scatter_ 逻辑,或将 penalty 应用逻辑提取为共享函数。

  2. 缺少测试覆盖 [P1]

    本 PR 修复了 4 个独立的正确性问题(do_sample 检查、clone、top_p mask、cum_log_probs),但未新增任何测试用例。这些修复涉及的场景(特别是 top_p 极小值导致 coredump、mixed do_sample batch)应有回归测试确保不再复现。

    建议:至少新增以下测试场景:

    • top_p 极小值(如 0.01)+ top_k=0 的采样不 crash
    • mixed do_sample batch 中非采样请求的输出不受 penalty 影响
    • cum_log_probs 在采样后值正确(与 CUDA 路径对比)

小问题

  1. do_sample 检查逻辑重复 [P2]
    has_not_do_sampleneed_do_sample 的计算逻辑在 CUDA 和 ROCm 两个 sampleGreedy 中完全相同。建议提取为 namespace {} 内的辅助函数,减少维护时两侧不一致的风险。

  2. CUDA/ROCm sampleGreedy 长期分叉 [P2]
    两个平台的 sampleGreedy 实现已有显著差异(CUDA 使用 flashinfer,ROCm 使用 PyTorch multinomial),但共享逻辑(do_sample 检查、penalty 应用、cum_log_probs 计算)仍在各自重复。建议后续将共享逻辑提取到平台无关的函数中。

  3. PR description 为空 [P2]
    PR body 为 null,未说明 coredump 的具体触发条件(哪个场景、什么输入参数组合)和修复思路。对于 bug fix PR,描述触发条件有助于 reviewer 验证修复完整性。

  4. 注释编号跳跃 [P3]
    步骤注释从 "6. Update cum_log_probs" 直接跳到 "8. Copy results back",缺少步骤 7。

整体评价

本 PR 修复了 ROCm greedy sampling 路径的多个真实正确性问题,特别是 top_p 过滤可能清零所有概率导致 coredump 的问题和 cum_log_probs 计算错误。修复方向正确,clone 和 mask[0]=false 的处理都是必要的。但 mixed batch 下 do_sample 的处理与 CUDA 路径存在行为差异(P1-2),且缺少测试覆盖(P1-1),建议补齐后合入。

存在阻塞/重要问题,不建议合入

@liaocz liaocz changed the title fix: fix rocm greedy sampling to avoid coredump fix: fix rocm greedy sampling to avoid crash Apr 24, 2026
@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review — PR #932

PR 概述

Title: fix: fix rocm greedy sampling to avoid crash
Author: liaocz
规模: 1(GitHub) files, +22/-10

核心目标

修复 ROCm 平台 legacy sampleGreedy 路径中的多个正确性问题:(1) 缺少 do_sample 标志处理;(2) filtered_probs 未 clone 导致原地修改污染原始概率;(3) sort() API 参数错误;(4) top_p 过滤可能清零所有 token 导致 coredump;(5) cum_log_probs 计算逻辑错误。


Review 意见

问题

  1. Mixed batch do_sample 处理与 CUDA 路径行为不一致 [P1]

    CUDA 路径在 mixed batch(部分请求 do_sample=true,部分 do_sample=false)时,会先保存非采样请求的原始 logits,对全 batch 执行 processLogits,然后用 masked_scatter_ 恢复非采样请求的 logits。ROCm 路径仅用 need_do_sample 做整体开关——只要 batch 中有任何一个请求需要采样,就对全 batch 施加 temperature 和 repetition penalty,do_sample=false 的请求的 logits 仍会被错误修改,导致 argmax 结果偏移。

    建议:在 ROCm 路径中增加与 CUDA 路径相同的 masked_select / masked_scatter_ 逻辑保护非采样请求的 logits。

  2. 缺少测试覆盖 [P1]

    本 PR 修复了 5 个独立的正确性问题,但未新增任何测试用例。建议至少补充:

    • mixed do_sample batch(部分 true 部分 false)
    • 极端 top_p 值(如 0.01)+ top_k=0 不 crash
    • cum_log_probs 正确性验证

小问题

  1. do_sample 检查逻辑重复 [P2] — has_not_do_sampleneed_do_sample 计算在 CUDA/ROCm 两侧完全相同,建议提取为共享辅助函数。
  2. CUDA/ROCm sampleGreedy 长期分叉 [P2] — 共享逻辑(do_sample 检查、penalty、cum_log_probs)建议后续提取到平台无关函数。
  3. 注释编号跳跃 [P3] — 步骤从 6 跳到 8,缺少 7。

纠正前次 AI Review 的错误

前次 AI review 提到 ROCmSampleOp.cc 存在相同 bug 需同步修复——该文件不存在。ROCm 的 sampleGreedy 位于 CudaSampleOp.cc#else 分支内(即本 PR 修改的位置)。

整体评价

本 PR 修复了 ROCm greedy sampling 路径的多个真实正确性问题,每项修复方向正确。PR description 质量好。主要风险在于:(1) mixed batch 下 do_sample 处理与 CUDA 路径不一致;(2) 缺少测试覆盖。

存在重要问题,不建议合入(P1 x 2)

Copy link
Copy Markdown
Collaborator

@LLLLKKKK LLLLKKKK left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM - testing CI gate workflow post-merge

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants