Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions whisper_live/backend/faster_whisper_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def transcribe_audio(self, input_sample):
use_vad=self.use_vad,
vad_parameters=self.vad_parameters if self.use_vad else None,
word_timestamps=self.word_timestamps,
client_uid=self.client_uid,
)
ServeClientFasterWhisper.BATCH_WORKER.submit(request)
request.future.wait(timeout=30)
Expand Down
97 changes: 75 additions & 22 deletions whisper_live/batch_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class BatchRequest:
initial_prompt: Optional[str] = None
use_vad: bool = True
vad_parameters: Optional[Dict] = None
word_timestamps: bool = False
client_uid: Optional[str] = None
# Signaling
future: threading.Event = field(default_factory=threading.Event)
# Results (filled by batch worker)
Expand Down Expand Up @@ -307,36 +309,87 @@ def _process_multi(self, batch: List[BatchRequest]):
tokenizers_list.append(tokenizer)
prompts.append(prompt)

# Step 4: Batch GPU generate
# Step 4: Batch GPU generate with per-item temperature fallback.
# Mirrors faster_whisper.transcribe()'s fallback loop. Items that
# pass quality thresholds at lower temperature keep their result;
# only failed items are re-decoded at the next temperature.
suppress_tokens = get_suppressed_tokens(tokenizers_list[0], [-1])

results = self.transcriber.model.generate(
encoder_output,
prompts,
beam_size=5,
patience=1,
length_penalty=1,
max_length=self.transcriber.max_length,
suppress_blank=True,
suppress_tokens=suppress_tokens,
return_scores=True,
return_no_speech_prob=True,
sampling_temperature=0.0,
repetition_penalty=1,
no_repeat_ngram_size=0,
)
temperatures = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
comp_thresh = 2.4
logprob_thresh = -1.0
no_speech_thresh = 0.6

# Step 5: Per-item segment parsing and result dispatch
for i, (req, features, audio, duration, speech_chunks) in enumerate(preprocessed):
try:
tokenizer = tokenizers_list[i]
gen_result = results[i]
n = len(preprocessed)
final_results = [None] * n # tuples of (gen_result, avg_logprob, used_temp)
pending_indices = list(range(n))

for temp in temperatures:
if not pending_indices:
break

if len(pending_indices) == n:
sub_encoder = encoder_output
else:
# Re-encode features for just the pending items to get
# an encoder_output of the right batch dimension.
sub_feature_batch = np.stack(
[preprocessed[i][1] for i in pending_indices]
)
sub_encoder = self.transcriber.encode(sub_feature_batch)
sub_prompts = [prompts[i] for i in pending_indices]

gen_kwargs = dict(
beam_size=5 if temp == 0.0 else 1,
patience=1,
length_penalty=1,
max_length=self.transcriber.max_length,
suppress_blank=True,
suppress_tokens=suppress_tokens,
return_scores=True,
return_no_speech_prob=True,
sampling_temperature=temp,
repetition_penalty=1,
no_repeat_ngram_size=0,
)
batch_results = self.transcriber.model.generate(
sub_encoder, sub_prompts, **gen_kwargs
)

next_pending = []
for j, idx in enumerate(pending_indices):
gen_result = batch_results[j]
tokens = gen_result.sequences_ids[0]
seq_len = len(tokens)
cum_logprob = gen_result.scores[0] * seq_len
avg_logprob = cum_logprob / (seq_len + 1) if seq_len > 0 else 0.0
raw_text = tokenizers_list[idx].decode(tokens).strip()
comp_ratio = get_compression_ratio(raw_text) if raw_text else 0.0

bad = (
comp_ratio > comp_thresh
or avg_logprob < logprob_thresh
)
# High no_speech + low logprob -> treat as silence, accept empty.
is_silence = (
gen_result.no_speech_prob > no_speech_thresh
and avg_logprob < logprob_thresh
)

if not bad or is_silence or temp == temperatures[-1]:
final_results[idx] = (gen_result, avg_logprob, temp)
else:
next_pending.append(idx)

pending_indices = next_pending

# Step 5: Per-item segment parsing and result dispatch
for i, (req, features, audio, duration, speech_chunks) in enumerate(preprocessed):
try:
tokenizer = tokenizers_list[i]
gen_result, avg_logprob, used_temp = final_results[i]

tokens = gen_result.sequences_ids[0]
segment_size = int(ceil(duration) * self.transcriber.frames_per_second)

subsegments, _, _ = self.transcriber._split_segments_by_timestamps(
Expand Down Expand Up @@ -364,7 +417,7 @@ def _process_multi(self, batch: List[BatchRequest]):
compression_ratio=get_compression_ratio(text),
no_speech_prob=gen_result.no_speech_prob,
words=None,
temperature=0.0,
temperature=used_temp,
))

req.result = segments
Expand Down
3 changes: 3 additions & 0 deletions whisper_live/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,9 @@ def run(self,
logging.info("Custom model option was provided. Switching to single model mode.")
self.single_model = True
# TODO: load model initially
elif batch_enabled:
logging.info("Batch inference enabled. Switching to single model mode for stock model.")
self.single_model = True
else:
logging.info("Single model mode currently only works with custom models.")
if not BackendType.is_valid(backend):
Expand Down
Loading