diff --git a/whisper_live/backend/faster_whisper_backend.py b/whisper_live/backend/faster_whisper_backend.py index 361d489e..8dcd9aaa 100644 --- a/whisper_live/backend/faster_whisper_backend.py +++ b/whisper_live/backend/faster_whisper_backend.py @@ -220,6 +220,7 @@ def transcribe_audio(self, input_sample): use_vad=self.use_vad, vad_parameters=self.vad_parameters if self.use_vad else None, word_timestamps=self.word_timestamps, + client_uid=self.client_uid, ) ServeClientFasterWhisper.BATCH_WORKER.submit(request) request.future.wait(timeout=30) diff --git a/whisper_live/batch_inference.py b/whisper_live/batch_inference.py index a15afb0d..f975dd4a 100644 --- a/whisper_live/batch_inference.py +++ b/whisper_live/batch_inference.py @@ -74,6 +74,8 @@ class BatchRequest: initial_prompt: Optional[str] = None use_vad: bool = True vad_parameters: Optional[Dict] = None + word_timestamps: bool = False + client_uid: Optional[str] = None # Signaling future: threading.Event = field(default_factory=threading.Event) # Results (filled by batch worker) @@ -307,36 +309,87 @@ def _process_multi(self, batch: List[BatchRequest]): tokenizers_list.append(tokenizer) prompts.append(prompt) - # Step 4: Batch GPU generate + # Step 4: Batch GPU generate with per-item temperature fallback. + # Mirrors faster_whisper.transcribe()'s fallback loop. Items that + # pass quality thresholds at lower temperature keep their result; + # only failed items are re-decoded at the next temperature. suppress_tokens = get_suppressed_tokens(tokenizers_list[0], [-1]) - results = self.transcriber.model.generate( - encoder_output, - prompts, - beam_size=5, - patience=1, - length_penalty=1, - max_length=self.transcriber.max_length, - suppress_blank=True, - suppress_tokens=suppress_tokens, - return_scores=True, - return_no_speech_prob=True, - sampling_temperature=0.0, - repetition_penalty=1, - no_repeat_ngram_size=0, - ) + temperatures = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] + comp_thresh = 2.4 + logprob_thresh = -1.0 + no_speech_thresh = 0.6 - # Step 5: Per-item segment parsing and result dispatch - for i, (req, features, audio, duration, speech_chunks) in enumerate(preprocessed): - try: - tokenizer = tokenizers_list[i] - gen_result = results[i] + n = len(preprocessed) + final_results = [None] * n # tuples of (gen_result, avg_logprob, used_temp) + pending_indices = list(range(n)) + + for temp in temperatures: + if not pending_indices: + break + + if len(pending_indices) == n: + sub_encoder = encoder_output + else: + # Re-encode features for just the pending items to get + # an encoder_output of the right batch dimension. + sub_feature_batch = np.stack( + [preprocessed[i][1] for i in pending_indices] + ) + sub_encoder = self.transcriber.encode(sub_feature_batch) + sub_prompts = [prompts[i] for i in pending_indices] + + gen_kwargs = dict( + beam_size=5 if temp == 0.0 else 1, + patience=1, + length_penalty=1, + max_length=self.transcriber.max_length, + suppress_blank=True, + suppress_tokens=suppress_tokens, + return_scores=True, + return_no_speech_prob=True, + sampling_temperature=temp, + repetition_penalty=1, + no_repeat_ngram_size=0, + ) + batch_results = self.transcriber.model.generate( + sub_encoder, sub_prompts, **gen_kwargs + ) + next_pending = [] + for j, idx in enumerate(pending_indices): + gen_result = batch_results[j] tokens = gen_result.sequences_ids[0] seq_len = len(tokens) cum_logprob = gen_result.scores[0] * seq_len avg_logprob = cum_logprob / (seq_len + 1) if seq_len > 0 else 0.0 + raw_text = tokenizers_list[idx].decode(tokens).strip() + comp_ratio = get_compression_ratio(raw_text) if raw_text else 0.0 + bad = ( + comp_ratio > comp_thresh + or avg_logprob < logprob_thresh + ) + # High no_speech + low logprob -> treat as silence, accept empty. + is_silence = ( + gen_result.no_speech_prob > no_speech_thresh + and avg_logprob < logprob_thresh + ) + + if not bad or is_silence or temp == temperatures[-1]: + final_results[idx] = (gen_result, avg_logprob, temp) + else: + next_pending.append(idx) + + pending_indices = next_pending + + # Step 5: Per-item segment parsing and result dispatch + for i, (req, features, audio, duration, speech_chunks) in enumerate(preprocessed): + try: + tokenizer = tokenizers_list[i] + gen_result, avg_logprob, used_temp = final_results[i] + + tokens = gen_result.sequences_ids[0] segment_size = int(ceil(duration) * self.transcriber.frames_per_second) subsegments, _, _ = self.transcriber._split_segments_by_timestamps( @@ -364,7 +417,7 @@ def _process_multi(self, batch: List[BatchRequest]): compression_ratio=get_compression_ratio(text), no_speech_prob=gen_result.no_speech_prob, words=None, - temperature=0.0, + temperature=used_temp, )) req.result = segments diff --git a/whisper_live/server.py b/whisper_live/server.py index 85046838..bb83f6b1 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -603,6 +603,9 @@ def run(self, logging.info("Custom model option was provided. Switching to single model mode.") self.single_model = True # TODO: load model initially + elif batch_enabled: + logging.info("Batch inference enabled. Switching to single model mode for stock model.") + self.single_model = True else: logging.info("Single model mode currently only works with custom models.") if not BackendType.is_valid(backend):