
support batched inference (text: list[str]) in generate()#35

Open
twangodev wants to merge 4 commits into HumeAI:main from twangodev:feat/batched-generation

Conversation

@twangodev

Summary

model.generate(text=list[str]) is a type-advertised API but doesn't actually work today: passing a multi-element list raises an IndexError on prompt.text_tokens_len[1] (text_tokens_len has shape [1]), and even with a hand-built prompt that dodges the index, everything past text[0] is silently discarded because the tokenization block two lines later only uses text[0].

This PR makes the advertised behavior work: a single batched forward pass that returns output.audio as a same-length list in input order.

Addresses the batched-inference ask in #2. Tokenization seam handling matches the fix in #16.
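
For reference, the pre-PR failure mode described above, as a minimal illustrative snippet (model and prompt loaded as usual):

```python
texts = ["First sentence.", "Second sentence."]

output = model.generate(prompt=prompt, text=texts)
# Before this PR this either raises IndexError on prompt.text_tokens_len[1]
# (text_tokens_len has shape [1]), or, with a hand-built prompt that dodges the
# index, silently synthesizes only texts[0] and drops the rest.
```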

Usage

```python
import torchaudio

texts = [
    "Hey, how's it going?",
    "The conference starts at nine in the morning.",
    "So I was at the coffee shop, and the barista asked me how my weekend was...",
]

output = model.generate(prompt=prompt, text=texts, num_transition_steps=5)

# output.audio is a list[Tensor], one per input text, same order
for wav in output.audio:
    torchaudio.save(..., wav.cpu().unsqueeze(0), 24000)
```

Inputs are right-padded internally; shorter samples exit via EOS and don't bleed into the tail of longer ones. model.compile() stacks
cleanly on top — use it for offline throughput.

See inference.ipynb section 12 for a runnable demo with RTF readout.
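
For offline throughput, the intended pattern is roughly the sketch below; chunked() and all_texts are hypothetical, and batch size 64 is just one of the configurations from the performance table.

```python
def chunked(items, n):
    # simple batching helper (illustrative)
    for i in range(0, len(items), n):
        yield items[i:i + n]

model.compile()  # one-time compile, reused across all batches

waveforms = []
for batch in chunked(all_texts, 64):
    out = model.generate(prompt=prompt, text=batch, num_transition_steps=5)
    waveforms.extend(out.audio)  # per-sample tensors, input order preserved
```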

What changed

_generate() — the AR loop

  • New attention_mask kwarg, built from pad tokens in generate().
  • Explicit per-sample position_ids once step >= max_input_len, so right-padded rows get RoPE at L_i + k, not max_len + k.
  • finished tracker zeros acoustic_features / time_* once a sample hits EOS, to stop ODE feedback from drifting into the zombie tail.
  • KV cache allocated at input_ids.shape[0], not hardcoded 1.
  • 3 unsqueeze(0) → unsqueeze(-1) / unsqueeze(1) fixes on time / acoustic tensors — these happened to produce (1, 1) at B=1 but are wrong for B>1.
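
Concretely, the per-step bookkeeping those bullets describe amounts to something like the sketch below. All names (input_lens, next_tokens, acoustic_features, eos_token_id) are illustrative stand-ins, not the PR's actual identifiers, and shapes are assumed as commented.

```python
import torch

def decode_step_bookkeeping(step, max_input_len, input_lens, next_tokens,
                            finished, acoustic_features, eos_token_id):
    """Illustrative sketch of the per-step batched handling; not the PR's code."""
    B = input_lens.shape[0]

    # Right-padded rows: once step >= max_input_len, row i should see RoPE
    # position L_i + k (real length plus decode offset), not max_input_len + k.
    k = step - max_input_len
    position_ids = (input_lens + k).unsqueeze(-1)              # (B, 1)

    # Rows that have emitted EOS are finished; zero their acoustic features so
    # the ODE feedback doesn't drift while longer rows keep decoding.
    finished = finished | next_tokens.eq(eos_token_id)         # (B,) bool
    acoustic_features = acoustic_features.masked_fill(finished.view(B, 1, 1), 0.0)

    return position_ids, finished, acoustic_features
```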

generate() — the public wrapper

  • Per-sample right-padded tokenization with <|finetune_right_pad_id|>.
  • Joint tokenization encode(prompt + " " + gen) mirrors the streaming-support fix in #16. Kept as its own commit (1135248) for easy review.
  • Post-hoc audio trim: end_frames = real_token_lens - start_offset - shift_acoustic, clamped. Drops the shift_acoustic zombie frames short samples would otherwise leak at the tail.
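
The wrapper-side steps above, roughly, as a sketch; tokenizer, pad_id, start_offset, and shift_acoustic are illustrative stand-ins for the real identifiers.

```python
import torch

def build_batch(tokenizer, prompt_text, texts, pad_id):
    """Illustrative: joint per-sample tokenization, right-padded into one batch."""
    token_ids = [tokenizer.encode(prompt_text + " " + t) for t in texts]  # one seam per sample
    real_token_lens = torch.tensor([len(ids) for ids in token_ids])
    max_len = int(real_token_lens.max())

    input_ids = torch.full((len(texts), max_len), pad_id, dtype=torch.long)
    for i, ids in enumerate(token_ids):
        input_ids[i, : len(ids)] = torch.tensor(ids)        # right padding
    attention_mask = input_ids.ne(pad_id)                   # True on real tokens
    return input_ids, attention_mask, real_token_lens

# After decoding, each sample's audio is trimmed to its real length so short
# samples don't keep the shift_acoustic "zombie" frames at the tail, per the bullet:
#   end_frames = (real_token_lens - start_offset - shift_acoustic).clamp(min=0)
```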

inference.ipynb

  • Adds section 12 with a mixed-length B=3 demo and RTF print.
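
The RTF readout in the notebook boils down to something like this (sketch; assumes 24 kHz output as in the usage example above):

```python
import time

t0 = time.perf_counter()
output = model.generate(prompt=prompt, text=texts, num_transition_steps=5)
wall = time.perf_counter() - t0

audio_s = sum(wav.numel() for wav in output.audio) / 24000  # total generated audio, in seconds
print(f"wall {wall:.2f}s | audio {audio_s:.2f}s | RTF {audio_s / wall:.1f}x")
```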

Performance

RTX 6000 Blackwell, tada-3b-ml, bf16, num_flow_matching_steps=10:

| Config | Throughput | Peak memory (GB) |
|---|---|---|
| B=1 | ~13× RT | 11 |
| B=64 | 185× RT | 11 |
| B=64 + model.compile() | 210× RT | 13 |
| B=256 + compile | 282× RT | 23 |
| B=512 + compile | 290× RT | 37 |

encode(prompt + " " + gen) avoids BPE merges across the seam that
can shift the first few gen tokens. matches the fix in PR HumeAI#16
(streaming support).
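
To illustrate the seam issue (hypothetical snippet, not code from the PR; tokenizer is whichever text tokenizer the model uses):

```python
# Encoding prompt and generation text separately can merge BPE tokens differently
# at the boundary than encoding them jointly, shifting the first few gen tokens.
separate = tokenizer.encode(prompt_text) + tokenizer.encode(" " + gen_text)
joint = tokenizer.encode(prompt_text + " " + gen_text)
# `separate` and `joint` may disagree near the seam; the batched path uses the
# joint form, matching the streaming fix in #16.
```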
Section 12 shows `model.generate(text=[...])` with a mixed-length batch of 3 texts, prints wall/audio/RTF, and plays each waveform.
Copilot AI review requested due to automatic review settings April 24, 2026 22:30

Copilot AI left a comment


Pull request overview

Implements true batched text inference for model.generate(text=list[str]) by right-padding inputs, adding attention masking/position handling in the AR loop, and returning per-sample audio outputs in input order.

Changes:

  • Add attention_mask support to _generate() and thread it through prefill + per-step forward calls.
  • Update generate() to jointly tokenize prompt + text per sample, right-pad to a batch, and build an attention mask.
  • Trim decoded audio per-sample based on the real (unpadded) token length to avoid “zombie tail” frames.


Comment thread on tada/modules/tada.py:
prefix_text_tokens is a 2-D tensor after expand(B, -1), so len() returns
B, not prefix_len. Pre-existing bug surfaced by the B>1 expand; input_lengths
isn't read downstream in _generate() so no behavior impact today.
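
For illustration (hypothetical shapes):

```python
import torch

prefix = torch.arange(7).unsqueeze(0).expand(4, -1)  # (B=4, prefix_len=7), as after expand(B, -1)
print(len(prefix))        # 4: len() of a 2-D tensor is its batch dim, not the prefix length
print(prefix.shape[-1])   # 7: the prefix length actually wanted
```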
@twangodev force-pushed the feat/batched-generation branch from 6286ed4 to 24245b2 on April 24, 2026 23:00