support batched inference (text: list[str]) in generate() #35
Open
twangodev wants to merge 4 commits into HumeAI:main from
Conversation
`encode(prompt + " " + gen)` avoids BPE merges across the seam that can shift the first few gen tokens. Matches the fix in PR HumeAI#16 (streaming support).
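To make the seam issue concrete, here is a minimal sketch with a BPE tokenizer (GPT-2 is used purely as a stand-in; the model's own tokenizer is what actually matters, and the example strings are illustrative):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")

prompt, gen = "lighthouse", "keeper logs"

a = tok.encode(prompt) + tok.encode(gen)   # per-piece: the seam is an artificial boundary
b = tok.encode(prompt + gen)               # merges may cross the seam and shift gen-side ids
c = tok.encode(prompt + " " + gen)         # the PR's form: explicit separator at the seam

print(a, b, c, sep="\n")  # the first few gen-side ids can differ between these
```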
Section 12 shows `model.generate(text=[...])` with a mixed-length batch of 3 texts, prints wall/audio/RTF, and plays each waveform.
Pull request overview
Implements true batched text inference for model.generate(text=list[str]) by right-padding inputs, adding attention masking/position handling in the AR loop, and returning per-sample audio outputs in input order.
Changes:
- Add `attention_mask` support to `_generate()` and thread it through prefill + per-step forward calls.
- Update `generate()` to jointly tokenize `prompt + text` per sample, right-pad to a batch, and build an attention mask (sketched below).
- Trim decoded audio per-sample based on the real (unpadded) token length to avoid "zombie tail" frames.
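A minimal sketch of that right-padding scheme, not the PR's exact code; `PAD_ID` stands in for the model's `<|finetune_right_pad_id|>`:

```python
import torch

PAD_ID = 0  # placeholder; the real id comes from the tokenizer

def pad_batch(token_lists: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
    max_len = max(len(t) for t in token_lists)
    input_ids = torch.full((len(token_lists), max_len), PAD_ID, dtype=torch.long)
    attention_mask = torch.zeros((len(token_lists), max_len), dtype=torch.long)
    for i, toks in enumerate(token_lists):
        input_ids[i, : len(toks)] = torch.tensor(toks)
        attention_mask[i, : len(toks)] = 1
    return input_ids, attention_mask

ids, mask = pad_batch([[5, 6, 7], [8, 9], [10]])
print(ids)   # rows right-padded with PAD_ID to the batch max length
print(mask)  # 1 over real tokens, 0 over the padding
```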
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
`prefix_text_tokens` is a 2-D tensor after `expand(B, -1)`, so `len()` returns `B`, not `prefix_len`. Pre-existing bug surfaced by the B>1 expand; `input_lengths` isn't read downstream in `_generate()`, so there's no behavior impact today.
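The pitfall in isolation (example values, not the PR's code):

```python
import torch

prefix = torch.arange(7)                     # 1-D, shape (7,)
batched = prefix.unsqueeze(0).expand(4, -1)  # 2-D, shape (4, 7) after expand(B, -1)

print(len(prefix))   # 7 == prefix_len
print(len(batched))  # 4 == B; len() on a tensor is the size of dim 0
```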
Force-pushed from 6286ed4 to 24245b2
Summary
`model.generate(text=list[str])` is a type-advertised API but doesn't actually work today: passing a multi-element list raises an IndexError on `prompt.text_tokens_len[1]` (`text_tokens_len` is shape `[1]`), or, with a hand-built prompt to dodge that, silently discards everything past `text[0]`, because the tokenization block two lines later uses only `text[0]`.

This PR makes the advertised behavior actually happen: one forward pass, returning `output.audio` as a same-length list in input order. Addresses the batched-inference ask in #2. Tokenization seam handling matches the fix in #16.
Usage
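A minimal call shape, mirroring the notebook demo; `model` is a loaded checkpoint (setup elided) and the texts are illustrative:

```python
texts = [
    "Short one.",
    "A somewhat longer sentence, so the batch has uneven lengths.",
    "And a third, longer still, so right-padding actually kicks in.",
]

output = model.generate(text=texts)

# output.audio: list of len(texts) waveforms, in input order,
# each trimmed to its own real length.
for wav in output.audio:
    print(wav.shape)
```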
Inputs are right-padded internally; shorter samples exit via EOS and don't bleed into the tail of longer ones.
`model.compile()` stacks cleanly on top; use it for offline throughput.
See `inference.ipynb` section 12 for a runnable demo with RTF readout.

What changed
`_generate()` (the AR loop):

- New `attention_mask` kwarg, built from pad tokens in `generate()`.
- Per-sample `position_ids` once `step >= max_input_len`, so right-padded rows get RoPE at `L_i + k`, not `max_len + k` (see the sketch after this list).
- A `finished` tracker zeros `acoustic_features`/`time_*` once a sample hits EOS, to stop ODE feedback from drifting into the zombie tail.
- Batch size comes from `input_ids.shape[0]`, not a hardcoded `1`.
- `unsqueeze(0)` → `unsqueeze(-1)` / `(1)` fixes on time/acoustic tensors; these were `(1, 1)` at B=1 and wrong for B>1.

`generate()` (the public wrapper):

- Right-pads with `<|finetune_right_pad_id|>`.
- `encode(prompt + " " + gen)` mirrors the streaming support #16 fix. Kept as its own commit (1135248) for easy review.
- `end_frames = real_token_lens - start_offset - shift_acoustic`, clamped. Drops the `shift_acoustic` zombie frames short samples would otherwise leak at the tail.

`inference.ipynb`:

- Section 12 adds a runnable batched demo with RTF readout.
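For the `position_ids` point above, a sketch under assumed shapes (names and values are illustrative, not the PR's exact code):

```python
import torch

real_lens = torch.tensor([5, 9, 7])   # true prompt length L_i per sample
max_input_len = int(real_lens.max())  # right-padded prompt length (9 here)

def decode_position_ids(step: int) -> torch.Tensor:
    # Once step >= max_input_len we are k = step - max_input_len tokens past
    # the padded prompt; each row resumes RoPE from its own true length.
    k = step - max_input_len
    return real_lens + k               # L_i + k, not max_len + k

print(decode_position_ids(9))  # tensor([5, 9, 7]): first generated token
```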
Performance

RTX 6000 Blackwell, `tada-3b-ml`, bf16, `num_flow_matching_steps=10`, measured with and without `model.compile()`.
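A sketch of how an RTF readout like the notebook's can be computed (`model`, `texts`, and `sample_rate` are assumptions here; RTF = wall-clock seconds per second of audio, so below 1.0 means faster than real time):

```python
import time

start = time.perf_counter()
output = model.generate(text=texts)  # batched API added by this PR
wall = time.perf_counter() - start

audio_sec = sum(len(w) for w in output.audio) / sample_rate
print(f"wall={wall:.2f}s  audio={audio_sec:.2f}s  RTF={wall / audio_sec:.3f}")
```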