Skip to content

add input embedding for pg#956

Open
parkerpang wants to merge 2 commits intoalibaba:mainfrom
parkerpang:feature/input_embedding_pg
Open

add input embedding for pg#956
parkerpang wants to merge 2 commits intoalibaba:mainfrom
parkerpang:feature/input_embedding_pg

Conversation

@parkerpang
Copy link
Copy Markdown

  • 支持 generate input 携带 input_embeddings,允许用户传入额外的 embedding 向量替换指定位置的 token embedding
  • 完整数据通路:Proto 定义 → RPC 序列化/反序列化 → GenerateStream 访问器 → NormalModelInputGatherer 收集
  • 在 setReuseLength 中增加保护,将 reuseLength cap 到 min(embedding_locs),防止 gather 时产生负索引

sw072 and others added 2 commits April 30, 2026 12:00
- Cap reuseLength to min(input_embeddings_locs) in setReuseLength to
  prevent negative adjusted locations during gather
- Fix test compilation: replace nonexistent setRunning()/batch_block_id
  with correct APIs matching existing test patterns
- Add testInputEmbeddingsReuseLengthCapped test case

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@parkerpang parkerpang requested a review from LLLLKKKK as a code owner April 30, 2026 04:02
@LLLLKKKK
Copy link
Copy Markdown
Collaborator

AI Code Review - PR #956

Status: LGTM

Summary: P0/0 · P1/0 · P2/2 · P3/2

lgtm ready to ci

Non-blocking Suggestions

P2

  • transQuery 缺少 embeddings_size 与 embedding_locs_size 一致性校验 @ rtp_llm/cpp/model_rpc/QueryConverter.cc:128
    • 建议:在转换前断言 input_embeddings_pb.embeddings_size() == input_embeddings_pb.embedding_locs_size(),不一致直接 throw(参考 EmbeddingQuery.cc 中对 input_embeddings 的形状校验)。
  • setReuseLength 中 cap 后 reuse_mm_length_ 未重新计算 @ rtp_llm/cpp/engine_base/stream/GenerateStream.cc:320
    • 建议:将 input_embeddings cap 移到函数开头(先 cap reuse_length_,再走 mm_locs 循环),或在 cap 之后重算 reuse_mm_length_,并在测试中覆盖 mm + input_embeddings 同时存在的场景。

P3

  • InputEmbeddings 类型注解与同模块风格不一致 @ rtp_llm/utils/base_model_datatypes.py:48
    • 建议:统一改为 Optional[InputEmbeddings] = None;如 black 行长 88 触发折行,可调整注释长度或与 batch_group_id 上一行同步处理,保持 dataclass 字段块视觉一致。
  • InputEmbeddings 缺少与 proto 一致的不变量校验 @ rtp_llm/utils/base_model_datatypes.py:28
    • 建议:在 init 末尾 assert len(embeddings) == len(embedding_locs),或在 GenerateInput.validate() 路径加入校验,避免错误数据被序列化进 protobuf 后才暴露。
Review Checklist: 5 pass / 5 fail

General Principles Checklist

Failed

  • [FAIL] 6.1 Architecture: 边界与不变量:跨进程/跨语言数据通道需校验关键不变量 Linked issue: transQuery 缺少 embeddings_size 与 embedding_locs_size 一致性校验
  • [FAIL] 6.1 Tests: 边界场景测试覆盖:组合状态需要专门测试 Linked issue: setReuseLength 中 cap 后 reuse_mm_length_ 未重新计算
  • [FAIL] 6.1 Quality: 修改集中、不混入无关格式化噪声 Linked issue: InputEmbeddings 类型注解与同模块风格不一致

Passed

  • [PASS] 6.1 Software Engineering: DRY/复用:避免与既有路径重复实现

RTP-LLM Checklist

Passed

  • [PASS] A 兼容性与配置: proto/RPC 字段新增需向后兼容
  • [PASS] B 正确性与逻辑: 位置/索引计算需考虑 reuseLength 偏移
  • [PASS] D 性能: 避免在 host/device 之间冗余拷贝
  • [PASS] H 测试与 CI: 新增功能需 C++ 单测

Python Static-First Checklist

Failed

  • [FAIL] P.A 类型与结构: 模块内类型注解风格一致 Linked issue: InputEmbeddings 类型注解与同模块风格不一致
  • [FAIL] P.B 数据校验: 自定义数据容器在 init 检查关键不变量 Linked issue: InputEmbeddings 缺少与 proto 一致的不变量校验

Strengths

  • 测试覆盖较全:FP32/FP16/BF16、与 decode stream 混合、reuseLength 被 cap、空 input_embeddings 等正反场景均有用例。
  • gatherInputEmbeddingsForContextBatch 严格仿照 gatherMultimodalFeaturesForContextBatch 的位置偏移规则(loc - reuseLength + token_idx),与既有 mm 路径一致,避免引入新偏移语义。
  • setReuseLength 中主动按 input_embeddings_locs 下界 cap reuse_length_,从源头上避免下游产生负 loc,是合理的防御式实现。
  • model_rpc proto 字段编号末尾追加(field 9),保持向后兼容。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants