diff --git a/backend/.env.template b/backend/.env.template index b4397a8caf..02c8bc6d9d 100644 --- a/backend/.env.template +++ b/backend/.env.template @@ -13,6 +13,7 @@ REDIS_DB_PORT= REDIS_DB_PASSWORD= DEEPGRAM_API_KEY= +MODULATE_API_KEY= ADMIN_KEY= OPENAI_API_KEY= diff --git a/backend/charts/backend-listen/dev_omi_backend_listen_values.yaml b/backend/charts/backend-listen/dev_omi_backend_listen_values.yaml index e751fb88c5..f2e1f540b6 100644 --- a/backend/charts/backend-listen/dev_omi_backend_listen_values.yaml +++ b/backend/charts/backend-listen/dev_omi_backend_listen_values.yaml @@ -91,6 +91,11 @@ env: secretKeyRef: name: dev-omi-backend-secrets key: DEEPGRAM_API_KEY + - name: MODULATE_API_KEY + valueFrom: + secretKeyRef: + name: dev-omi-backend-secrets + key: MODULATE_API_KEY - name: FAL_KEY valueFrom: secretKeyRef: diff --git a/backend/charts/backend-listen/prod_omi_backend_listen_values.yaml b/backend/charts/backend-listen/prod_omi_backend_listen_values.yaml index 745a9279c9..90e5247ec2 100644 --- a/backend/charts/backend-listen/prod_omi_backend_listen_values.yaml +++ b/backend/charts/backend-listen/prod_omi_backend_listen_values.yaml @@ -143,6 +143,11 @@ env: secretKeyRef: name: prod-omi-backend-secrets key: DEEPGRAM_API_KEY + - name: MODULATE_API_KEY + valueFrom: + secretKeyRef: + name: prod-omi-backend-secrets + key: MODULATE_API_KEY - name: GOOGLE_MAPS_API_KEY valueFrom: secretKeyRef: diff --git a/backend/charts/backend-secrets/dev_omi_backend_secrets_values.yaml b/backend/charts/backend-secrets/dev_omi_backend_secrets_values.yaml index 7955f21ddc..d5c5225c23 100644 --- a/backend/charts/backend-secrets/dev_omi_backend_secrets_values.yaml +++ b/backend/charts/backend-secrets/dev_omi_backend_secrets_values.yaml @@ -1,82 +1,84 @@ -gke: - projectID: based-hardware-dev - clusterLocation: us-central1 - clusterName: dev-omi-gke - -gsa: - name: dev-omi-backend-eso-gsa@based-hardware-dev.iam.gserviceaccount.com - -serviceAccount: - name: dev-omi-backend-eso-ksa - -externalSecret: - name: dev-omi-backend-external-secret - targetSecretName: dev-omi-backend-secrets - refreshInterval: 1h - secretKeys: # secretKey is the key in the Kubernetes secret, remoteKey is the key in the Secrets Manager - - secretKey: HUGGINGFACE_TOKEN - remoteKey: HUGGINGFACE_TOKEN - - secretKey: DEEPGRAM_API_KEY - remoteKey: DEEPGRAM_API_KEY - - secretKey: FAL_KEY - remoteKey: FAL_KEY - - secretKey: OPENAI_API_KEY - remoteKey: OPENAI_API_KEY - - secretKey: GOOGLE_MAPS_API_KEY - remoteKey: GOOGLE_MAPS_API_KEY - - secretKey: GITHUB_TOKEN - remoteKey: GITHUB_TOKEN - - secretKey: PINECONE_API_KEY - remoteKey: PINECONE_API_KEY - - secretKey: REDIS_DB_HOST - remoteKey: REDIS_DB_HOST - - secretKey: REDIS_DB_PASSWORD - remoteKey: REDIS_DB_PASSWORD - - secretKey: ADMIN_KEY - remoteKey: ADMIN_KEY - - secretKey: GOOGLE_APPLICATION_CREDENTIALS - remoteKey: GOOGLE_APPLICATION_CREDENTIALS - - secretKey: DD_API_KEY - remoteKey: DD_API_KEY - - secretKey: LANGCHAIN_API_KEY - remoteKey: LANGCHAIN_API_KEY - - secretKey: STRIPE_API_KEY - remoteKey: STRIPE_API_KEY - - secretKey: STRIPE_WEBHOOK_SECRET - remoteKey: STRIPE_WEBHOOK_SECRET - - secretKey: MARKETPLACE_APP_REVIEWERS - remoteKey: MARKETPLACE_APP_REVIEWERS - - secretKey: TYPESENSE_HOST - remoteKey: TYPESENSE_HOST - - secretKey: TYPESENSE_API_KEY - remoteKey: TYPESENSE_API_KEY - - secretKey: STT_SERVICE_MODELS - remoteKey: STT_SERVICE_MODELS - - secretKey: ENCRYPTION_SECRET - remoteKey: ENCRYPTION_SECRET - - secretKey: SERVICE_ACCOUNT_JSON - remoteKey: SERVICE_ACCOUNT_JSON - - 
secretKey: GEMINI_API_KEY - remoteKey: GEMINI_API_KEY - - secretKey: TWILIO_ACCOUNT_SID - remoteKey: TWILIO_ACCOUNT_SID - - secretKey: TWILIO_AUTH_TOKEN - remoteKey: TWILIO_AUTH_TOKEN - - secretKey: TWILIO_API_KEY_SID - remoteKey: TWILIO_API_KEY_SID - - secretKey: TWILIO_API_KEY_SECRET - remoteKey: TWILIO_API_KEY_SECRET - - secretKey: TWILIO_TWIML_APP_SID - remoteKey: TWILIO_TWIML_APP_SID - - secretKey: METRICS_SECRET - remoteKey: METRICS_SECRET - - secretKey: GROQ_API_KEY - remoteKey: GROQ_API_KEY - - secretKey: OPENROUTER_API_KEY - remoteKey: OPENROUTER_API_KEY - - secretKey: RAPID_API_HOST - remoteKey: RAPID_API_HOST - - secretKey: RAPID_API_KEY - remoteKey: RAPID_API_KEY - - secretKey: CONVERSATION_SUMMARIZED_APP_IDS - remoteKey: CONVERSATION_SUMMARIZED_APP_IDS +gke: + projectID: based-hardware-dev + clusterLocation: us-central1 + clusterName: dev-omi-gke + +gsa: + name: dev-omi-backend-eso-gsa@based-hardware-dev.iam.gserviceaccount.com + +serviceAccount: + name: dev-omi-backend-eso-ksa + +externalSecret: + name: dev-omi-backend-external-secret + targetSecretName: dev-omi-backend-secrets + refreshInterval: 1h + secretKeys: # secretKey is the key in the Kubernetes secret, remoteKey is the key in the Secrets Manager + - secretKey: HUGGINGFACE_TOKEN + remoteKey: HUGGINGFACE_TOKEN + - secretKey: DEEPGRAM_API_KEY + remoteKey: DEEPGRAM_API_KEY + - secretKey: MODULATE_API_KEY + remoteKey: MODULATE_API_KEY + - secretKey: FAL_KEY + remoteKey: FAL_KEY + - secretKey: OPENAI_API_KEY + remoteKey: OPENAI_API_KEY + - secretKey: GOOGLE_MAPS_API_KEY + remoteKey: GOOGLE_MAPS_API_KEY + - secretKey: GITHUB_TOKEN + remoteKey: GITHUB_TOKEN + - secretKey: PINECONE_API_KEY + remoteKey: PINECONE_API_KEY + - secretKey: REDIS_DB_HOST + remoteKey: REDIS_DB_HOST + - secretKey: REDIS_DB_PASSWORD + remoteKey: REDIS_DB_PASSWORD + - secretKey: ADMIN_KEY + remoteKey: ADMIN_KEY + - secretKey: GOOGLE_APPLICATION_CREDENTIALS + remoteKey: GOOGLE_APPLICATION_CREDENTIALS + - secretKey: DD_API_KEY + remoteKey: DD_API_KEY + - secretKey: LANGCHAIN_API_KEY + remoteKey: LANGCHAIN_API_KEY + - secretKey: STRIPE_API_KEY + remoteKey: STRIPE_API_KEY + - secretKey: STRIPE_WEBHOOK_SECRET + remoteKey: STRIPE_WEBHOOK_SECRET + - secretKey: MARKETPLACE_APP_REVIEWERS + remoteKey: MARKETPLACE_APP_REVIEWERS + - secretKey: TYPESENSE_HOST + remoteKey: TYPESENSE_HOST + - secretKey: TYPESENSE_API_KEY + remoteKey: TYPESENSE_API_KEY + - secretKey: STT_SERVICE_MODELS + remoteKey: STT_SERVICE_MODELS + - secretKey: ENCRYPTION_SECRET + remoteKey: ENCRYPTION_SECRET + - secretKey: SERVICE_ACCOUNT_JSON + remoteKey: SERVICE_ACCOUNT_JSON + - secretKey: GEMINI_API_KEY + remoteKey: GEMINI_API_KEY + - secretKey: TWILIO_ACCOUNT_SID + remoteKey: TWILIO_ACCOUNT_SID + - secretKey: TWILIO_AUTH_TOKEN + remoteKey: TWILIO_AUTH_TOKEN + - secretKey: TWILIO_API_KEY_SID + remoteKey: TWILIO_API_KEY_SID + - secretKey: TWILIO_API_KEY_SECRET + remoteKey: TWILIO_API_KEY_SECRET + - secretKey: TWILIO_TWIML_APP_SID + remoteKey: TWILIO_TWIML_APP_SID + - secretKey: METRICS_SECRET + remoteKey: METRICS_SECRET + - secretKey: GROQ_API_KEY + remoteKey: GROQ_API_KEY + - secretKey: OPENROUTER_API_KEY + remoteKey: OPENROUTER_API_KEY + - secretKey: RAPID_API_HOST + remoteKey: RAPID_API_HOST + - secretKey: RAPID_API_KEY + remoteKey: RAPID_API_KEY + - secretKey: CONVERSATION_SUMMARIZED_APP_IDS + remoteKey: CONVERSATION_SUMMARIZED_APP_IDS diff --git a/backend/charts/backend-secrets/prod_omi_backend_secrets_values.yaml b/backend/charts/backend-secrets/prod_omi_backend_secrets_values.yaml index 
098901cec3..4c1af2babc 100644 --- a/backend/charts/backend-secrets/prod_omi_backend_secrets_values.yaml +++ b/backend/charts/backend-secrets/prod_omi_backend_secrets_values.yaml @@ -1,80 +1,82 @@ -gke: - projectID: based-hardware - clusterLocation: us-central1 - clusterName: prod-omi-gke - -gsa: - name: prod-omi-backend-eso-gsa@based-hardware.iam.gserviceaccount.com - -serviceAccount: - name: prod-omi-backend-eso-ksa - -externalSecret: - name: prod-omi-backend-external-secret - targetSecretName: prod-omi-backend-secrets - refreshInterval: 1h - secretKeys: # secretKey is the key in the Kubernetes secret, remoteKey is the key in the Secrets Manager - - secretKey: HUGGINGFACE_TOKEN - remoteKey: HUGGINGFACE_TOKEN - - secretKey: GITHUB_TOKEN - remoteKey: GITHUB_TOKEN - - secretKey: OPENAI_API_KEY - remoteKey: OPENAI_API_KEY - - secretKey: FAL_KEY - remoteKey: FAL_KEY - - secretKey: GROQ_API_KEY - remoteKey: GROQ_API_KEY - - secretKey: PINECONE_API_KEY - remoteKey: PINECONE_API_KEY - - secretKey: REDIS_DB_HOST - remoteKey: REDIS_DB_HOST - - secretKey: REDIS_DB_PASSWORD - remoteKey: REDIS_DB_PASSWORD - - secretKey: ADMIN_KEY - remoteKey: ADMIN_KEY - - secretKey: DEEPGRAM_API_KEY - remoteKey: DEEPGRAM_API_KEY - - secretKey: GOOGLE_MAPS_API_KEY - remoteKey: GOOGLE_MAPS_API_KEY - - secretKey: GOOGLE_APPLICATION_CREDENTIALS - remoteKey: GOOGLE_APPLICATION_CREDENTIALS - - secretKey: LANGCHAIN_API_KEY - remoteKey: LANGCHAIN_API_KEY - - secretKey: DD_API_KEY - remoteKey: DD_API_KEY - - secretKey: STRIPE_API_KEY - remoteKey: STRIPE_API_KEY - - secretKey: STRIPE_WEBHOOK_SECRET - remoteKey: STRIPE_WEBHOOK_SECRET - - secretKey: MARKETPLACE_APP_REVIEWERS - remoteKey: MARKETPLACE_APP_REVIEWERS - - secretKey: TYPESENSE_HOST - remoteKey: TYPESENSE_HOST - - secretKey: TYPESENSE_API_KEY - remoteKey: TYPESENSE_API_KEY - - secretKey: RAPID_API_HOST - remoteKey: RAPID_API_HOST - - secretKey: RAPID_API_KEY - remoteKey: RAPID_API_KEY - - secretKey: OPENROUTER_API_KEY - remoteKey: OPENROUTER_API_KEY - - secretKey: CONVERSATION_SUMMARIZED_APP_IDS - remoteKey: CONVERSATION_SUMMARIZED_APP_IDS - - secretKey: STT_SERVICE_MODELS - remoteKey: STT_SERVICE_MODELS - - secretKey: ENCRYPTION_SECRET - remoteKey: ENCRYPTION_SECRET - - secretKey: GEMINI_API_KEY - remoteKey: GEMINI_API_KEY - - secretKey: TWILIO_ACCOUNT_SID - remoteKey: TWILIO_ACCOUNT_SID - - secretKey: TWILIO_AUTH_TOKEN - remoteKey: TWILIO_AUTH_TOKEN - - secretKey: TWILIO_API_KEY_SID - remoteKey: TWILIO_API_KEY_SID - - secretKey: TWILIO_API_KEY_SECRET - remoteKey: TWILIO_API_KEY_SECRET - - secretKey: TWILIO_TWIML_APP_SID - remoteKey: TWILIO_TWIML_APP_SID - - secretKey: METRICS_SECRET - remoteKey: METRICS_SECRET +gke: + projectID: based-hardware + clusterLocation: us-central1 + clusterName: prod-omi-gke + +gsa: + name: prod-omi-backend-eso-gsa@based-hardware.iam.gserviceaccount.com + +serviceAccount: + name: prod-omi-backend-eso-ksa + +externalSecret: + name: prod-omi-backend-external-secret + targetSecretName: prod-omi-backend-secrets + refreshInterval: 1h + secretKeys: # secretKey is the key in the Kubernetes secret, remoteKey is the key in the Secrets Manager + - secretKey: HUGGINGFACE_TOKEN + remoteKey: HUGGINGFACE_TOKEN + - secretKey: GITHUB_TOKEN + remoteKey: GITHUB_TOKEN + - secretKey: OPENAI_API_KEY + remoteKey: OPENAI_API_KEY + - secretKey: FAL_KEY + remoteKey: FAL_KEY + - secretKey: GROQ_API_KEY + remoteKey: GROQ_API_KEY + - secretKey: PINECONE_API_KEY + remoteKey: PINECONE_API_KEY + - secretKey: REDIS_DB_HOST + remoteKey: REDIS_DB_HOST + - secretKey: 
REDIS_DB_PASSWORD + remoteKey: REDIS_DB_PASSWORD + - secretKey: ADMIN_KEY + remoteKey: ADMIN_KEY + - secretKey: DEEPGRAM_API_KEY + remoteKey: DEEPGRAM_API_KEY + - secretKey: MODULATE_API_KEY + remoteKey: MODULATE_API_KEY + - secretKey: GOOGLE_MAPS_API_KEY + remoteKey: GOOGLE_MAPS_API_KEY + - secretKey: GOOGLE_APPLICATION_CREDENTIALS + remoteKey: GOOGLE_APPLICATION_CREDENTIALS + - secretKey: LANGCHAIN_API_KEY + remoteKey: LANGCHAIN_API_KEY + - secretKey: DD_API_KEY + remoteKey: DD_API_KEY + - secretKey: STRIPE_API_KEY + remoteKey: STRIPE_API_KEY + - secretKey: STRIPE_WEBHOOK_SECRET + remoteKey: STRIPE_WEBHOOK_SECRET + - secretKey: MARKETPLACE_APP_REVIEWERS + remoteKey: MARKETPLACE_APP_REVIEWERS + - secretKey: TYPESENSE_HOST + remoteKey: TYPESENSE_HOST + - secretKey: TYPESENSE_API_KEY + remoteKey: TYPESENSE_API_KEY + - secretKey: RAPID_API_HOST + remoteKey: RAPID_API_HOST + - secretKey: RAPID_API_KEY + remoteKey: RAPID_API_KEY + - secretKey: OPENROUTER_API_KEY + remoteKey: OPENROUTER_API_KEY + - secretKey: CONVERSATION_SUMMARIZED_APP_IDS + remoteKey: CONVERSATION_SUMMARIZED_APP_IDS + - secretKey: STT_SERVICE_MODELS + remoteKey: STT_SERVICE_MODELS + - secretKey: ENCRYPTION_SECRET + remoteKey: ENCRYPTION_SECRET + - secretKey: GEMINI_API_KEY + remoteKey: GEMINI_API_KEY + - secretKey: TWILIO_ACCOUNT_SID + remoteKey: TWILIO_ACCOUNT_SID + - secretKey: TWILIO_AUTH_TOKEN + remoteKey: TWILIO_AUTH_TOKEN + - secretKey: TWILIO_API_KEY_SID + remoteKey: TWILIO_API_KEY_SID + - secretKey: TWILIO_API_KEY_SECRET + remoteKey: TWILIO_API_KEY_SECRET + - secretKey: TWILIO_TWIML_APP_SID + remoteKey: TWILIO_TWIML_APP_SID + - secretKey: METRICS_SECRET + remoteKey: METRICS_SECRET diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 04d5e57afc..ff7d16f0fe 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -5,6 +5,7 @@ import logging import os import random +import audioop import struct import time import uuid @@ -74,8 +75,9 @@ STTService, get_stt_service_for_language, process_audio_dg, + process_audio_modulate, ) -from utils.stt.vad_gate import VADStreamingGate, VAD_GATE_MODE, is_gate_enabled +from utils.stt.vad_gate import GatedSTTSocket, VADStreamingGate, VAD_GATE_MODE, is_gate_enabled from utils.fair_use import ( FAIR_USE_ENABLED, FAIR_USE_CHECK_INTERVAL_SECONDS, @@ -925,25 +927,28 @@ def _update_in_progress_conversation( return # Process STT - deepgram_socket = None + stt_socket = None vad_gate = None def stream_transcript(segments): nonlocal realtime_segment_buffers - # Note: DG timestamp remapping is handled inside GatedDeepgramSocket wrapper realtime_segment_buffers.extend(segments) + async def _create_stt_socket(callback, lang, sr, model, kw=None, active_check=None): + if stt_service == STTService.modulate: + return await process_audio_modulate(callback, sr, lang) + return await process_audio_dg(callback, lang, sr, 1, model=model, keywords=kw, is_active=active_check) + async def _process_stt(): nonlocal websocket_close_code - nonlocal deepgram_socket + nonlocal stt_socket try: if use_custom_stt: logger.info(f"Custom STT mode enabled - using suggested transcripts from app {uid} {session_id}") return None if is_multi_channel: - # Create one STT connection per channel for i, ch_config in enumerate(channel_configs): def make_multi_channel_callback(cfg): @@ -956,23 +961,14 @@ def cb(segments): return cb callback = make_multi_channel_callback(ch_config) - stt_sockets_multi[i] = await process_audio_dg( - callback, - stt_language, - TARGET_SAMPLE_RATE, - 1, - 
model=stt_model, + stt_sockets_multi[i] = await _create_stt_socket( + callback, stt_language, TARGET_SAMPLE_RATE, stt_model ) logger.info( f"Multi-channel STT connections established ({len(channel_configs)} channels) {uid} {session_id}" ) return None - # Initialize VAD gate for all eligible DG sessions. - # Gate requires PCM16 LE (linear16). All codecs (opus, aac, lc3) - # decode to int16 before buffering. pcm8/pcm16 are linear16 from hardware - # (the "8"/"16" refers to sample rate kHz, not bit depth). - # DG always receives mono (channels=1), so clamp gate channels to 1. nonlocal vad_gate gate_enabled_by_override = vad_gate_override == 'enabled' gate_disabled_by_override = vad_gate_override == 'disabled' @@ -981,7 +977,7 @@ def cb(segments): try: vad_gate = VADStreamingGate( sample_rate=sample_rate, - channels=1, # DG always receives mono (encoding=linear16, channels=1) + channels=1, mode=gate_mode, uid=uid, session_id=session_id, @@ -998,16 +994,29 @@ def cb(segments): logger.exception('VAD gate init failed, continuing without gate uid=%s session=%s', uid, session_id) vad_gate = None - deepgram_socket = await process_audio_dg( - stream_transcript, + def _make_stream_callback(callback): + if vad_gate is not None: + + def wrapped(segments): + vad_gate.remap_segments(segments) + callback(segments) + + return wrapped + return callback + + raw_socket = await _create_stt_socket( + _make_stream_callback(stream_transcript), stt_language, sample_rate, - 1, - model=stt_model, - keywords=vocabulary[:100] if vocabulary else None, - vad_gate=vad_gate, - is_active=lambda: websocket_active, + stt_model, + kw=vocabulary[:100] if vocabulary else None, + active_check=lambda: websocket_active, ) + if vad_gate is not None and raw_socket is not None: + passthrough = stt_service == STTService.modulate + stt_socket = GatedSTTSocket(raw_socket, gate=vad_gate, passthrough_audio=passthrough) + else: + stt_socket = raw_socket return None except Exception as e: @@ -2306,7 +2315,7 @@ async def handle_image_chunk(uid: str, chunk_data: dict, image_chunks_cache: dic elif codec == 'lc3': lc3_decoder = lc3.Decoder(lc3_frame_duration_us, sample_rate) - async def receive_data(dg_socket): + async def receive_data(stt_socket): nonlocal websocket_active, websocket_close_code, last_audio_received_time, last_activity_time, current_conversation_id nonlocal realtime_photo_buffers, speaker_to_person_map, first_audio_byte_timestamp, last_usage_record_timestamp nonlocal audio_ring_buffer, dg_usage_ms_pending @@ -2319,7 +2328,7 @@ async def receive_data(dg_socket): stt_buffer_flush_size = int(sample_rate * 2 * 0.03) # 30ms at 16-bit mono (e.g., 6400 bytes at 16kHz) async def flush_stt_buffer(force: bool = False): - nonlocal stt_audio_buffer, dg_usage_ms_pending, dg_socket + nonlocal stt_audio_buffer, dg_usage_ms_pending, stt_socket if not stt_audio_buffer: return @@ -2329,24 +2338,21 @@ async def flush_stt_buffer(force: bool = False): chunk = bytes(stt_audio_buffer) stt_audio_buffer.clear() - # Check if DG connection died (keepalive or send failure) (#5870) - if dg_socket is not None and dg_socket.is_connection_dead: - close_reason = dg_socket.death_reason or 'unknown' + if stt_socket is not None and stt_socket.is_connection_dead: + close_reason = stt_socket.death_reason or 'unknown' logger.error( - 'DG connection died mid-session uid=%s session=%s reason=%s', + 'STT connection died mid-session uid=%s session=%s reason=%s', uid, session_id, close_reason, ) - dg_socket = None # Stop sending to dead connection + stt_socket = None - if 
dg_socket is not None: - # DG budget gate: skip sending if daily budget is exhausted (#5746, #6083) + if stt_socket is not None: if fair_use_dg_budget_exhausted: - pass # Audio not forwarded to DG — budget/credits exhausted + pass else: - dg_socket.send(chunk) - # Accumulate DG usage locally, flushed every 60s (#5854) + stt_socket.send(chunk) if fair_use_track_dg_usage: chunk_ms = len(chunk) * 1000 // (sample_rate * 2) # 16-bit mono dg_usage_ms_pending += chunk_ms @@ -2464,12 +2470,16 @@ async def flush_stt_buffer(force: bool = False): ) continue + if codec == 'pcm8': + data = audioop.bias(data, 1, -128) + data = audioop.lin2lin(data, 1, 2) + # Feed ring buffer for speaker identification (always, with wall-clock time) if audio_ring_buffer is not None: audio_ring_buffer.write(data, last_audio_received_time) if not use_custom_stt: - # VAD gating is handled inside GatedDeepgramSocket.send() + # VAD gating is handled inside GatedSTTSocket.send() stt_audio_buffer.extend(data) await flush_stt_buffer() @@ -2561,6 +2571,21 @@ async def flush_stt_buffer(force: bool = False): # Flush any remaining audio in buffer to STT if not use_custom_stt: await flush_stt_buffer(force=True) + # EOS drain: send EOS and wait for final transcripts while + # stream_transcript_process is still running (before websocket_active=False) + try: + if is_multi_channel: + for mc_stt_socket in stt_sockets_multi: + if mc_stt_socket and hasattr(mc_stt_socket, 'drain_and_close'): + await mc_stt_socket.drain_and_close() + else: + drain_target = stt_socket + if isinstance(stt_socket, GatedSTTSocket): + drain_target = stt_socket._conn + if drain_target and hasattr(drain_target, 'drain_and_close'): + await drain_target.drain_and_close() + except Exception as e: + logger.error(f"Error draining STT EOS: {e} {uid} {session_id}") websocket_active = False # Start @@ -2614,7 +2639,7 @@ async def flush_stt_buffer(force: bool = False): pusher_tasks.append(asyncio.create_task(pusher_heartbeat())) # Tasks - data_process_task = asyncio.create_task(receive_data(deepgram_socket)) + data_process_task = asyncio.create_task(receive_data(stt_socket)) stream_transcript_task = asyncio.create_task(stream_transcript_process()) record_usage_task = asyncio.create_task(_record_usage_periodically()) @@ -2687,9 +2712,8 @@ async def flush_stt_buffer(force: bool = False): if mc_stt_socket: mc_stt_socket.finish() else: - if deepgram_socket: - # GatedDeepgramSocket.finish() handles finalize automatically - deepgram_socket.finish() + if stt_socket: + stt_socket.finish() except Exception as e: logger.error(f"Error closing STT sockets: {e} {uid} {session_id}") diff --git a/backend/scripts/stt/l_benchmark_prerecorded.py b/backend/scripts/stt/l_benchmark_prerecorded.py new file mode 100644 index 0000000000..61767f0c59 --- /dev/null +++ b/backend/scripts/stt/l_benchmark_prerecorded.py @@ -0,0 +1,287 @@ +""" +Benchmark: Deepgram vs Modulate — Pre-recorded transcription. + +Generates 10+ diverse test audio samples and runs both providers, +measuring latency, word count, and WER against reference text. 
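+
+WER is computed with jiwer on lowercased transcripts; as a minimal illustration
+of the metric (toy strings, not benchmark output):
+
+    from jiwer import wer
+    wer('the quick brown fox', 'the quick brown dog')  # 1 sub / 4 ref words = 0.25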
+ +Usage: + cd backend && python scripts/stt/l_benchmark_prerecorded.py +""" + +import asyncio +import json +import os +import subprocess +import sys +import time +from pathlib import Path +from typing import List, Tuple + +from dotenv import load_dotenv + +load_dotenv(Path(__file__).resolve().parents[2] / '.env') + +from jiwer import wer as compute_wer +from tabulate import tabulate + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from utils.stt.pre_recorded import deepgram_prerecorded_from_bytes, modulate_prerecorded_from_bytes + +AUDIO_DIR = Path('/tmp/stt_benchmark_audio') +RESULTS_DIR = Path('/tmp/stt_benchmark_results') + +BENCHMARK_CASES: List[dict] = [ + { + 'id': 'short_greeting', + 'text': 'Hello, how are you doing today?', + 'lang': 'en', + 'description': 'Short greeting (5 words)', + }, + { + 'id': 'medium_sentence', + 'text': 'The quick brown fox jumps over the lazy dog near the old oak tree in the park.', + 'lang': 'en', + 'description': 'Medium sentence (17 words)', + }, + { + 'id': 'technical_jargon', + 'text': 'The server processes incoming websocket connections on port eight thousand and eighty, using TLS encryption for secure data transmission.', + 'lang': 'en', + 'description': 'Technical content with numbers', + }, + { + 'id': 'conversational', + 'text': "Well, I think we should probably go to the store and pick up some groceries before it closes. What do you think?", + 'lang': 'en', + 'description': 'Conversational with fillers', + }, + { + 'id': 'numbers_dates', + 'text': 'The meeting is scheduled for January fifteenth, twenty twenty six, at three thirty in the afternoon.', + 'lang': 'en', + 'description': 'Numbers and dates', + }, + { + 'id': 'medical_terms', + 'text': 'The patient was diagnosed with bilateral pneumonia and prescribed amoxicillin for ten days along with regular monitoring.', + 'lang': 'en', + 'description': 'Medical terminology', + }, + { + 'id': 'long_paragraph', + 'text': ( + 'Artificial intelligence has transformed many industries over the past decade. ' + 'Machine learning models can now understand natural language, generate images, ' + 'and even write code. However, there are still significant challenges in ensuring ' + 'these systems are reliable, safe, and aligned with human values.' + ), + 'lang': 'en', + 'description': 'Long paragraph (40+ words)', + }, + { + 'id': 'names_places', + 'text': 'Doctor Sarah Chen from Stanford University presented her findings at the conference in San Francisco, California.', + 'lang': 'en', + 'description': 'Proper nouns (names, places)', + }, + { + 'id': 'question_answer', + 'text': "What is the capital of France? The capital of France is Paris, which is located along the Seine River.", + 'lang': 'en', + 'description': 'Question and answer format', + }, + { + 'id': 'instructions', + 'text': ( + 'First, open the application settings. Then navigate to the audio section. ' + 'Select the input device and set the sample rate to sixteen thousand hertz. ' + 'Finally, click save to apply your changes.' + ), + 'lang': 'en', + 'description': 'Step-by-step instructions', + }, + { + 'id': 'emotional_speech', + 'text': "This is absolutely incredible! I can't believe we finally got it working after all these months of effort.", + 'lang': 'en', + 'description': 'Emotional/exclamatory speech', + }, + { + 'id': 'multi_speaker_sim', + 'text': ( + 'Good morning everyone. Today we will discuss the quarterly results. ' + 'Revenue increased by fifteen percent compared to last quarter. 
' + 'Our customer satisfaction scores have also improved significantly.' + ), + 'lang': 'en', + 'description': 'Meeting-style multi-sentence', + }, +] + + +def generate_audio(case: dict, output_path: Path) -> None: + tmp_raw = output_path.with_suffix('.raw.wav') + subprocess.run( + ['espeak-ng', '-v', case['lang'], '-w', str(tmp_raw), '--', case['text']], + check=True, + capture_output=True, + ) + subprocess.run( + ['ffmpeg', '-y', '-i', str(tmp_raw), '-ar', '16000', '-ac', '1', '-sample_fmt', 's16', str(output_path)], + check=True, + capture_output=True, + ) + tmp_raw.unlink(missing_ok=True) + + +def run_deepgram(audio_bytes: bytes) -> Tuple[str, float, int]: + t0 = time.monotonic() + result = deepgram_prerecorded_from_bytes(audio_bytes, sample_rate=16000, diarize=True) + elapsed = time.monotonic() - t0 + text = ' '.join(w.get('text', '') or w.get('word', '') for w in result).strip() + return text, elapsed, len(result) + + +def run_modulate(audio_bytes: bytes) -> Tuple[str, float, int]: + t0 = time.monotonic() + result = modulate_prerecorded_from_bytes(audio_bytes, sample_rate=16000, diarize=True) + elapsed = time.monotonic() - t0 + text = ' '.join(w.get('text', '') for w in result).strip() + return text, elapsed, len(result) + + +def main(): + AUDIO_DIR.mkdir(parents=True, exist_ok=True) + RESULTS_DIR.mkdir(parents=True, exist_ok=True) + + dg_key = os.getenv('DEEPGRAM_API_KEY') + mod_key = os.getenv('MODULATE_API_KEY') + if not dg_key: + print('ERROR: DEEPGRAM_API_KEY not set') + sys.exit(1) + if not mod_key: + print('ERROR: MODULATE_API_KEY not set') + sys.exit(1) + + print(f'Generating {len(BENCHMARK_CASES)} test audio samples...') + for case in BENCHMARK_CASES: + wav_path = AUDIO_DIR / f"{case['id']}.wav" + if not wav_path.exists(): + generate_audio(case, wav_path) + print(f" Generated: {case['id']} ({wav_path.stat().st_size / 1024:.1f} KB)") + else: + print(f" Cached: {case['id']}") + + print(f'\nRunning pre-recorded benchmarks ({len(BENCHMARK_CASES)} cases x 2 providers)...\n') + + results = [] + for case in BENCHMARK_CASES: + wav_path = AUDIO_DIR / f"{case['id']}.wav" + audio_bytes = wav_path.read_bytes() + ref_text = case['text'].lower() + + row = { + 'id': case['id'], + 'description': case['description'], + 'ref_words': len(case['text'].split()), + 'audio_kb': wav_path.stat().st_size / 1024, + } + + print(f" [{case['id']}] {case['description']}") + + try: + dg_text, dg_time, dg_segments = run_deepgram(audio_bytes) + dg_wer = compute_wer(ref_text, dg_text.lower()) if dg_text else 1.0 + row.update( + { + 'dg_time': dg_time, + 'dg_words': len(dg_text.split()) if dg_text else 0, + 'dg_wer': dg_wer, + 'dg_text': dg_text, + 'dg_segments': dg_segments, + } + ) + print(f" Deepgram: {dg_time:.2f}s WER={dg_wer:.2%} words={len(dg_text.split())}") + except Exception as e: + print(f" Deepgram: ERROR - {e}") + row.update({'dg_time': -1, 'dg_words': 0, 'dg_wer': 1.0, 'dg_text': f'ERROR: {e}', 'dg_segments': 0}) + + try: + mod_text, mod_time, mod_segments = run_modulate(audio_bytes) + mod_wer = compute_wer(ref_text, mod_text.lower()) if mod_text else 1.0 + row.update( + { + 'mod_time': mod_time, + 'mod_words': len(mod_text.split()) if mod_text else 0, + 'mod_wer': mod_wer, + 'mod_text': mod_text, + 'mod_segments': mod_segments, + } + ) + print(f" Modulate: {mod_time:.2f}s WER={mod_wer:.2%} words={len(mod_text.split())}") + except Exception as e: + print(f" Modulate: ERROR - {e}") + row.update({'mod_time': -1, 'mod_words': 0, 'mod_wer': 1.0, 'mod_text': f'ERROR: {e}', 'mod_segments': 0}) + + 
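+        # Failed provider calls leave -1 latencies and "ERROR: ..." text as sentinels
+        # in the row, so the summary table and JSON dump below still cover every case.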
results.append(row) + + print('\n' + '=' * 100) + print('PRE-RECORDED BENCHMARK RESULTS') + print('=' * 100) + + table_data = [] + for r in results: + table_data.append( + [ + r['id'], + r['ref_words'], + f"{r['audio_kb']:.1f}", + f"{r.get('dg_time', -1):.2f}s" if r.get('dg_time', -1) >= 0 else 'ERR', + f"{r.get('dg_wer', 1):.1%}", + r.get('dg_words', 0), + f"{r.get('mod_time', -1):.2f}s" if r.get('mod_time', -1) >= 0 else 'ERR', + f"{r.get('mod_wer', 1):.1%}", + r.get('mod_words', 0), + ] + ) + + print( + tabulate( + table_data, + headers=[ + 'Case', + 'Ref Words', + 'Audio KB', + 'DG Time', + 'DG WER', + 'DG Words', + 'Mod Time', + 'Mod WER', + 'Mod Words', + ], + tablefmt='grid', + ) + ) + + valid_dg = [r for r in results if r.get('dg_time', -1) >= 0] + valid_mod = [r for r in results if r.get('mod_time', -1) >= 0] + + print('\nSUMMARY:') + if valid_dg: + avg_dg_time = sum(r['dg_time'] for r in valid_dg) / len(valid_dg) + avg_dg_wer = sum(r['dg_wer'] for r in valid_dg) / len(valid_dg) + print(f" Deepgram: avg_latency={avg_dg_time:.2f}s avg_WER={avg_dg_wer:.1%} cases={len(valid_dg)}") + if valid_mod: + avg_mod_time = sum(r['mod_time'] for r in valid_mod) / len(valid_mod) + avg_mod_wer = sum(r['mod_wer'] for r in valid_mod) / len(valid_mod) + print(f" Modulate: avg_latency={avg_mod_time:.2f}s avg_WER={avg_mod_wer:.1%} cases={len(valid_mod)}") + + output_path = RESULTS_DIR / 'prerecorded_benchmark.json' + with open(output_path, 'w') as f: + json.dump(results, f, indent=2) + print(f'\nDetailed results saved to: {output_path}') + + +if __name__ == '__main__': + main() diff --git a/backend/scripts/stt/m_benchmark_streaming.py b/backend/scripts/stt/m_benchmark_streaming.py new file mode 100644 index 0000000000..f8b243ae22 --- /dev/null +++ b/backend/scripts/stt/m_benchmark_streaming.py @@ -0,0 +1,437 @@ +""" +Benchmark: Deepgram vs Modulate — Streaming transcription. + +Generates 10+ diverse test audio samples and streams them through both +providers, measuring connection latency, first-segment latency, total +segments, transcription time, and WER against reference text. + +Usage: + cd backend && python scripts/stt/m_benchmark_streaming.py +""" + +import asyncio +import json +import os +import subprocess +import sys +import time +from pathlib import Path +from typing import List, Tuple + +from dotenv import load_dotenv + +load_dotenv(Path(__file__).resolve().parents[2] / '.env') + +from jiwer import wer as compute_wer +from tabulate import tabulate + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from utils.stt.streaming import process_audio_dg, process_audio_modulate + +AUDIO_DIR = Path('/tmp/stt_benchmark_audio') +RESULTS_DIR = Path('/tmp/stt_benchmark_results') + +BENCHMARK_CASES: List[dict] = [ + { + 'id': 'short_greeting', + 'text': 'Hello, how are you doing today?', + 'lang': 'en', + 'description': 'Short greeting (5 words)', + }, + { + 'id': 'medium_sentence', + 'text': 'The quick brown fox jumps over the lazy dog near the old oak tree in the park.', + 'lang': 'en', + 'description': 'Medium sentence (17 words)', + }, + { + 'id': 'technical_jargon', + 'text': 'The server processes incoming websocket connections on port eight thousand and eighty, using TLS encryption for secure data transmission.', + 'lang': 'en', + 'description': 'Technical content with numbers', + }, + { + 'id': 'conversational', + 'text': "Well, I think we should probably go to the store and pick up some groceries before it closes. 
What do you think?", + 'lang': 'en', + 'description': 'Conversational with fillers', + }, + { + 'id': 'numbers_dates', + 'text': 'The meeting is scheduled for January fifteenth, twenty twenty six, at three thirty in the afternoon.', + 'lang': 'en', + 'description': 'Numbers and dates', + }, + { + 'id': 'medical_terms', + 'text': 'The patient was diagnosed with bilateral pneumonia and prescribed amoxicillin for ten days along with regular monitoring.', + 'lang': 'en', + 'description': 'Medical terminology', + }, + { + 'id': 'long_paragraph', + 'text': ( + 'Artificial intelligence has transformed many industries over the past decade. ' + 'Machine learning models can now understand natural language, generate images, ' + 'and even write code. However, there are still significant challenges in ensuring ' + 'these systems are reliable, safe, and aligned with human values.' + ), + 'lang': 'en', + 'description': 'Long paragraph (40+ words)', + }, + { + 'id': 'names_places', + 'text': 'Doctor Sarah Chen from Stanford University presented her findings at the conference in San Francisco, California.', + 'lang': 'en', + 'description': 'Proper nouns (names, places)', + }, + { + 'id': 'question_answer', + 'text': "What is the capital of France? The capital of France is Paris, which is located along the Seine River.", + 'lang': 'en', + 'description': 'Question and answer format', + }, + { + 'id': 'instructions', + 'text': ( + 'First, open the application settings. Then navigate to the audio section. ' + 'Select the input device and set the sample rate to sixteen thousand hertz. ' + 'Finally, click save to apply your changes.' + ), + 'lang': 'en', + 'description': 'Step-by-step instructions', + }, + { + 'id': 'emotional_speech', + 'text': "This is absolutely incredible! I can't believe we finally got it working after all these months of effort.", + 'lang': 'en', + 'description': 'Emotional/exclamatory speech', + }, + { + 'id': 'multi_speaker_sim', + 'text': ( + 'Good morning everyone. Today we will discuss the quarterly results. ' + 'Revenue increased by fifteen percent compared to last quarter. ' + 'Our customer satisfaction scores have also improved significantly.' 
+ ), + 'lang': 'en', + 'description': 'Meeting-style multi-sentence', + }, +] + +CHUNK_SIZE = 3200 +CHUNK_INTERVAL = 0.1 + + +def generate_audio(case: dict, output_path: Path) -> None: + tmp_raw = output_path.with_suffix('.raw.wav') + subprocess.run( + ['espeak-ng', '-v', case['lang'], '-w', str(tmp_raw), '--', case['text']], + check=True, + capture_output=True, + ) + subprocess.run( + ['ffmpeg', '-y', '-i', str(tmp_raw), '-ar', '16000', '-ac', '1', '-sample_fmt', 's16', str(output_path)], + check=True, + capture_output=True, + ) + tmp_raw.unlink(missing_ok=True) + + +def read_pcm_from_wav(wav_path: Path) -> bytes: + data = wav_path.read_bytes() + if data[:4] == b'RIFF': + return data[44:] + return data + + +async def stream_to_deepgram(audio_pcm: bytes, language: str) -> dict: + segments_received = [] + first_segment_time = [None] + connect_start = time.monotonic() + + def stream_transcript(segments): + if first_segment_time[0] is None: + first_segment_time[0] = time.monotonic() + segments_received.extend(segments) + + try: + socket = await asyncio.wait_for( + process_audio_dg(stream_transcript, language, 16000, 1, model='nova-3'), + timeout=15, + ) + except Exception as e: + return {'error': str(e), 'connect_time': -1} + + connect_time = time.monotonic() - connect_start + stream_start = time.monotonic() + + offset = 0 + while offset < len(audio_pcm): + chunk = audio_pcm[offset : offset + CHUNK_SIZE] + socket.send(chunk) + offset += CHUNK_SIZE + await asyncio.sleep(CHUNK_INTERVAL) + + socket.finish() + await asyncio.sleep(3) + + total_time = time.monotonic() - stream_start + text = ' '.join(s.get('text', '') for s in segments_received).strip() + first_seg_latency = (first_segment_time[0] - stream_start) if first_segment_time[0] else -1 + + return { + 'connect_time': connect_time, + 'first_segment_latency': first_seg_latency, + 'total_time': total_time, + 'segments': len(segments_received), + 'text': text, + 'words': len(text.split()) if text else 0, + } + + +async def stream_to_modulate(audio_pcm: bytes, language: str) -> dict: + segments_received = [] + first_segment_time = [None] + connect_start = time.monotonic() + + def stream_transcript(segments): + if first_segment_time[0] is None: + first_segment_time[0] = time.monotonic() + segments_received.extend(segments) + + try: + socket = await asyncio.wait_for( + process_audio_modulate(stream_transcript, 16000, language), + timeout=15, + ) + except Exception as e: + return {'error': str(e), 'connect_time': -1} + + connect_time = time.monotonic() - connect_start + stream_start = time.monotonic() + + offset = 0 + while offset < len(audio_pcm): + chunk = audio_pcm[offset : offset + CHUNK_SIZE] + socket.send(chunk) + offset += CHUNK_SIZE + await asyncio.sleep(CHUNK_INTERVAL) + + try: + await asyncio.wait_for(socket.drain_and_close(), timeout=20) + except (asyncio.TimeoutError, Exception): + pass + + total_time = time.monotonic() - stream_start + text = ' '.join(s.get('text', '') for s in segments_received).strip() + first_seg_latency = (first_segment_time[0] - stream_start) if first_segment_time[0] else -1 + + return { + 'connect_time': connect_time, + 'first_segment_latency': first_seg_latency, + 'total_time': total_time, + 'segments': len(segments_received), + 'text': text, + 'words': len(text.split()) if text else 0, + } + + +async def run_benchmark(): + AUDIO_DIR.mkdir(parents=True, exist_ok=True) + RESULTS_DIR.mkdir(parents=True, exist_ok=True) + + dg_key = os.getenv('DEEPGRAM_API_KEY') + mod_key = os.getenv('MODULATE_API_KEY') + if not 
dg_key: + print('ERROR: DEEPGRAM_API_KEY not set') + sys.exit(1) + if not mod_key: + print('ERROR: MODULATE_API_KEY not set') + sys.exit(1) + + print(f'Generating {len(BENCHMARK_CASES)} test audio samples...') + for case in BENCHMARK_CASES: + wav_path = AUDIO_DIR / f"{case['id']}.wav" + if not wav_path.exists(): + generate_audio(case, wav_path) + print(f" Generated: {case['id']} ({wav_path.stat().st_size / 1024:.1f} KB)") + else: + print(f" Cached: {case['id']}") + + print(f'\nRunning streaming benchmarks ({len(BENCHMARK_CASES)} cases x 2 providers)...\n') + + results = [] + for case in BENCHMARK_CASES: + wav_path = AUDIO_DIR / f"{case['id']}.wav" + audio_pcm = read_pcm_from_wav(wav_path) + ref_text = case['text'].lower() + lang = case['lang'] + + row = { + 'id': case['id'], + 'description': case['description'], + 'ref_words': len(case['text'].split()), + 'audio_kb': len(audio_pcm) / 1024, + } + + print(f" [{case['id']}] {case['description']}") + + try: + dg_result = await stream_to_deepgram(audio_pcm, lang) + if 'error' in dg_result: + raise RuntimeError(dg_result['error']) + dg_wer = compute_wer(ref_text, dg_result['text'].lower()) if dg_result['text'] else 1.0 + row.update( + { + 'dg_connect': dg_result['connect_time'], + 'dg_first_seg': dg_result['first_segment_latency'], + 'dg_total': dg_result['total_time'], + 'dg_segments': dg_result['segments'], + 'dg_words': dg_result['words'], + 'dg_wer': dg_wer, + 'dg_text': dg_result['text'], + } + ) + print( + f" Deepgram: connect={dg_result['connect_time']:.2f}s " + f"first_seg={dg_result['first_segment_latency']:.2f}s " + f"total={dg_result['total_time']:.2f}s " + f"segs={dg_result['segments']} WER={dg_wer:.2%}" + ) + except Exception as e: + print(f" Deepgram: ERROR - {e}") + row.update( + { + 'dg_connect': -1, + 'dg_first_seg': -1, + 'dg_total': -1, + 'dg_segments': 0, + 'dg_words': 0, + 'dg_wer': 1.0, + 'dg_text': f'ERROR: {e}', + } + ) + + try: + mod_result = await stream_to_modulate(audio_pcm, lang) + if 'error' in mod_result: + raise RuntimeError(mod_result['error']) + mod_wer = compute_wer(ref_text, mod_result['text'].lower()) if mod_result['text'] else 1.0 + row.update( + { + 'mod_connect': mod_result['connect_time'], + 'mod_first_seg': mod_result['first_segment_latency'], + 'mod_total': mod_result['total_time'], + 'mod_segments': mod_result['segments'], + 'mod_words': mod_result['words'], + 'mod_wer': mod_wer, + 'mod_text': mod_result['text'], + } + ) + print( + f" Modulate: connect={mod_result['connect_time']:.2f}s " + f"first_seg={mod_result['first_segment_latency']:.2f}s " + f"total={mod_result['total_time']:.2f}s " + f"segs={mod_result['segments']} WER={mod_wer:.2%}" + ) + except Exception as e: + print(f" Modulate: ERROR - {e}") + row.update( + { + 'mod_connect': -1, + 'mod_first_seg': -1, + 'mod_total': -1, + 'mod_segments': 0, + 'mod_words': 0, + 'mod_wer': 1.0, + 'mod_text': f'ERROR: {e}', + } + ) + + results.append(row) + + print('\n' + '=' * 120) + print('STREAMING BENCHMARK RESULTS') + print('=' * 120) + + table_data = [] + for r in results: + + def fmt_time(v): + return f"{v:.2f}s" if v >= 0 else 'ERR' + + table_data.append( + [ + r['id'], + r['ref_words'], + fmt_time(r.get('dg_connect', -1)), + fmt_time(r.get('dg_first_seg', -1)), + fmt_time(r.get('dg_total', -1)), + r.get('dg_segments', 0), + f"{r.get('dg_wer', 1):.0%}", + fmt_time(r.get('mod_connect', -1)), + fmt_time(r.get('mod_first_seg', -1)), + fmt_time(r.get('mod_total', -1)), + r.get('mod_segments', 0), + f"{r.get('mod_wer', 1):.0%}", + ] + ) + + print( + 
tabulate( + table_data, + headers=[ + 'Case', + 'Words', + 'DG Conn', + 'DG 1st Seg', + 'DG Total', + 'DG Segs', + 'DG WER', + 'Mod Conn', + 'Mod 1st Seg', + 'Mod Total', + 'Mod Segs', + 'Mod WER', + ], + tablefmt='grid', + ) + ) + + valid_dg = [r for r in results if r.get('dg_total', -1) >= 0] + valid_mod = [r for r in results if r.get('mod_total', -1) >= 0] + + print('\nSUMMARY:') + if valid_dg: + print( + f" Deepgram: " + f"avg_connect={sum(r['dg_connect'] for r in valid_dg) / len(valid_dg):.2f}s " + f"avg_first_seg={sum(r['dg_first_seg'] for r in valid_dg if r['dg_first_seg'] >= 0) / max(1, len([r for r in valid_dg if r['dg_first_seg'] >= 0])):.2f}s " + f"avg_total={sum(r['dg_total'] for r in valid_dg) / len(valid_dg):.2f}s " + f"avg_WER={sum(r['dg_wer'] for r in valid_dg) / len(valid_dg):.1%} " + f"cases={len(valid_dg)}" + ) + if valid_mod: + print( + f" Modulate: " + f"avg_connect={sum(r['mod_connect'] for r in valid_mod) / len(valid_mod):.2f}s " + f"avg_first_seg={sum(r['mod_first_seg'] for r in valid_mod if r['mod_first_seg'] >= 0) / max(1, len([r for r in valid_mod if r['mod_first_seg'] >= 0])):.2f}s " + f"avg_total={sum(r['mod_total'] for r in valid_mod) / len(valid_mod):.2f}s " + f"avg_WER={sum(r['mod_wer'] for r in valid_mod) / len(valid_mod):.1%} " + f"cases={len(valid_mod)}" + ) + + output_path = RESULTS_DIR / 'streaming_benchmark.json' + with open(output_path, 'w') as f: + json.dump(results, f, indent=2) + print(f'\nDetailed results saved to: {output_path}') + + +def main(): + asyncio.run(run_benchmark()) + + +if __name__ == '__main__': + main() diff --git a/backend/scripts/stt/modulate_repro/README.md b/backend/scripts/stt/modulate_repro/README.md new file mode 100644 index 0000000000..9ccdd3ba50 --- /dev/null +++ b/backend/scripts/stt/modulate_repro/README.md @@ -0,0 +1,72 @@ +# Modulate Velma-2: Non-deterministic utterance ordering + +## Issue + +Sending the **same audio** to Modulate's Velma-2 streaming API with **identical parameters** produces utterances in **different order** across runs. + +## Test audio + +`test_audio.wav` — 38s, 16kHz, mono, PCM16. Contains 4 spoken utterances from LibriSpeech (public domain) separated by 5 seconds of silence: + +| # | Utterance | Source | +|---|-----------|--------| +| 1 | "He hoped there would be stew for dinner, turnips and carrots..." | LibriSpeech test-clean/1089/134686/0000 | +| 2 | "Stuff it into you, his belly counselled him." | LibriSpeech test-clean/1089/134686/0001 | +| 3 | "After early nightfall the yellow lamps would light up..." | LibriSpeech test-clean/1089/134686/0002 | +| 4 | "Hello Bertie, any good in your mind?" | LibriSpeech test-clean/1089/134686/0003 | + +**Download:** If the WAV is not included in your copy, download from GCS: +```bash +curl -o test_audio.wav "https://storage.googleapis.com/omi-pr-assets/modulate-repro/test_audio.wav" +``` + +## Reproduce + +```bash +pip install websockets +export MODULATE_API_KEY=your_key_here +python repro_utterance_order.py --runs 5 +``` + +Or pass the key directly: +```bash +python repro_utterance_order.py --api-key YOUR_KEY --runs 5 +``` + +## Expected + +All runs return utterances in order: `1→2→3→4` + +## Observed + +Order varies between runs. 
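+
+How the order is scored (a simplified sketch of the check in
+`repro_utterance_order.py`; helper and field names are illustrative):
+
+```python
+EXPECTED = [
+    'He hoped there would be stew for dinner',
+    'Stuff it into you',
+    'After early nightfall',
+    'Hello Bertie',
+]
+
+def arrival_order(utterances):
+    # Map each transcript, in arrival order, to the expected utterance it opens with.
+    order = []
+    for text in (u['text'].strip().lower() for u in utterances):
+        for i, expected in enumerate(EXPECTED, start=1):
+            if text.startswith(expected.lower()):
+                order.append(i)
+                break
+    return order  # e.g. [2, 1, 3, 4] on a wrong-order run
+```
+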
Example from 5 identical runs: + +``` +Run 1: [1→2→3→4] CORRECT +Run 2: [2→1→3→4] WRONG ORDER +Run 3: [2→1→3→4] WRONG ORDER +Run 4: [1→2→3→4] CORRECT +Run 5: [2→1→3→4] WRONG ORDER + +Order consistency: 2/5 runs correct (40%) +``` + +## Impact + +This non-deterministic ordering causes WER measurements on the same audio to swing wildly between runs (e.g., 5% to 75% WER on identical input) because WER is computed on the concatenated utterance text. + +The `start_ms` timestamps in the returned utterances are correct — utterance 1 always has the earliest `start_ms`. But the **arrival order** over the WebSocket is not guaranteed to match the temporal order. + +## API parameters used + +``` +wss://modulate-developer-apis.com/api/velma-2-stt-streaming + ?speaker_diarization=true + &partial_results=true + &sample_rate=16000 + &audio_format=s16le + &num_channels=1 + &language=en +``` + +Audio streamed in 100ms chunks (3200 bytes) at real-time pacing. diff --git a/backend/scripts/stt/modulate_repro/repro_utterance_order.py b/backend/scripts/stt/modulate_repro/repro_utterance_order.py new file mode 100644 index 0000000000..0481fb2cf9 --- /dev/null +++ b/backend/scripts/stt/modulate_repro/repro_utterance_order.py @@ -0,0 +1,238 @@ +""" +Modulate Velma-2 STT: Non-deterministic utterance ordering reproduction. + +Sends the SAME WAV file to Modulate's streaming API multiple times. +Demonstrates that utterance arrival order varies between identical runs, +causing inconsistent WER measurements. + +The test WAV contains 4 spoken utterances separated by 5s silence: + 1. "He hoped there would be stew for dinner, turnips and carrots..." + 2. "Stuff it into you, his belly counselled him." + 3. "After early nightfall the yellow lamps would light up..." + 4. "Hello Bertie, any good in your mind?" + +Expected: utterances arrive in order 1→2→3→4 every time. +Observed: order varies between runs (e.g., 2→1→3→4 or 1→2→3→4). 
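+
+The start_ms values on returned utterances are correct even when arrival order
+is not, so a client-side workaround (illustrative, not applied by this script)
+is to re-sort before concatenating:
+
+    ordered = sorted(utterances, key=lambda u: u['start_ms'])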
+
+Requirements:
+    pip install websockets
+
+Usage:
+    python repro_utterance_order.py                     # 5 runs
+    python repro_utterance_order.py --runs 10           # 10 runs
+    python repro_utterance_order.py --api-key YOUR_KEY  # custom key
+"""
+
+import argparse
+import asyncio
+import json
+import os
+import struct
+import sys
+import time
+import urllib.parse
+from pathlib import Path
+
+import websockets
+
+SCRIPT_DIR = Path(__file__).parent
+DEFAULT_WAV = SCRIPT_DIR / 'test_audio.wav'
+
+EXPECTED_UTTERANCE_ORDER = [
+    'He hoped there would be stew for dinner',
+    'Stuff it into you',
+    'After early nightfall',
+    'Hello Bertie',
+]
+
+
+def read_wav_pcm(wav_path):
+    """Read WAV file and return (sample_rate, pcm_bytes)."""
+    with open(wav_path, 'rb') as f:
+        riff = f.read(4)
+        if riff != b'RIFF':
+            raise ValueError(f'Not a WAV file: {wav_path}')
+        f.read(4)  # file size
+        wave = f.read(4)
+        if wave != b'WAVE':
+            raise ValueError(f'Not a WAV file: {wav_path}')
+
+        sample_rate = 16000
+        while True:
+            chunk_id = f.read(4)
+            if len(chunk_id) < 4:
+                break
+            chunk_size = struct.unpack('<I', f.read(4))[0]
+            chunk = f.read(chunk_size)
+            if chunk_id == b'fmt ':
+                # Sample rate sits at byte offset 4 of the fmt chunk body.
+                sample_rate = struct.unpack('<I', chunk[4:8])[0]
+            elif chunk_id == b'data':
+                return sample_rate, chunk
+        raise ValueError(f'No data chunk found in {wav_path}')
diff --git a/backend/scripts/stt/n_benchmark_02_prerecorded.py b/backend/scripts/stt/n_benchmark_02_prerecorded.py new file mode 100644 --- /dev/null +++ b/backend/scripts/stt/n_benchmark_02_prerecorded.py
+"""
+Benchmark Suite 02: Deepgram vs Modulate — Pre-recorded transcription.
+
+Uses real human speech from LibriSpeech test-clean (12 samples with ground
+truth transcripts for accurate WER measurement).
+
+Usage:
+    cd backend && python scripts/stt/n_benchmark_02_prerecorded.py --prepare
+    cd backend && python scripts/stt/n_benchmark_02_prerecorded.py
+"""
+
+import json
+import os
+import re
+import subprocess
+import sys
+import time
+from pathlib import Path
+from typing import List, Tuple
+
+from dotenv import load_dotenv
+
+load_dotenv(Path(__file__).resolve().parents[2] / '.env')
+
+from jiwer import wer as compute_wer
+from tabulate import tabulate
+
+sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
+
+from utils.stt.pre_recorded import deepgram_prerecorded_from_bytes, modulate_prerecorded_from_bytes
+
+PUNCT_RE = re.compile(r'[^\w\s]', re.UNICODE)
+
+
+def normalize_for_wer(text: str) -> str:
+    return PUNCT_RE.sub('', text).lower().strip()
+
+
+def count_punctuation(text: str) -> dict:
+    marks = re.findall(r'[^\w\s]', text)
+    return {'total': len(marks), 'detail': dict(sorted(((m, marks.count(m)) for m in set(marks)), key=lambda x: -x[1]))}
+
+
+AUDIO_DIR = Path('/tmp/stt_benchmark_audio_02')
+RESULTS_DIR = Path('/tmp/stt_benchmark_results')
+LIBRISPEECH_TAR = Path('/tmp/test-clean.tar.gz')
+LIBRISPEECH_DIR = Path('/tmp/librispeech/LibriSpeech/test-clean')
+
+SAMPLE_PICKS = [
+    {'uid': '5683-32865-0000', 'desc': 'Short utterance (4 words, 2.2s)'},
+    {'uid': '672-122797-0057', 'desc': 'Short sentence (7 words, 6.6s)'},
+    {'uid': '2830-3980-0027', 'desc': 'Short phrase (8 words, 2.3s)'},
+    {'uid': '3570-5694-0004', 'desc': 'Medium sentence (16 words, 5.3s)'},
+    {'uid': '5142-33396-0012', 'desc': 'Medium dialog (18 words, 4.6s)'},
+    {'uid': '8463-287645-0008', 'desc': 'Medium phrase (10 words, 3.3s)'},
+    {'uid': '1580-141084-0024', 'desc': 'Long narrative (27 words, 9.2s)'},
+    {'uid': '4970-29093-0019', 'desc': 'Long narrative (23 words, 7.5s)'},
+    {'uid': '1284-1180-0006', 'desc': 'Long descriptive (22 words, 6.9s)'},
+    {'uid': '4077-13751-0009', 'desc': 'Very long passage (33 words, 12.2s)'},
+    {'uid': '2961-960-0000', 'desc': 'Very long passage (51 words, 27.2s)'},
+    {'uid': '3729-6852-0006', 'desc': 'Very long passage (62 words, 23.7s)'},
+]
+
+
+def prepare_samples():
+    if not LIBRISPEECH_TAR.exists():
+        print('ERROR: Download LibriSpeech test-clean first:')
+        print(f'  curl -L -o {LIBRISPEECH_TAR} https://www.openslr.org/resources/12/test-clean.tar.gz')
+        sys.exit(1)
+
+    if not LIBRISPEECH_DIR.exists():
+        print('Extracting LibriSpeech test-clean...')
+        LIBRISPEECH_DIR.parent.parent.mkdir(parents=True, exist_ok=True)
+        subprocess.run(['tar', 'xzf', str(LIBRISPEECH_TAR), '-C', str(LIBRISPEECH_DIR.parent.parent)], check=True)
+
+    AUDIO_DIR.mkdir(parents=True, exist_ok=True)
+    manifest = []
+
+    for i, pick in enumerate(SAMPLE_PICKS):
+        uid = pick['uid']
+        parts = uid.split('-')
+        speaker, chapter = parts[0], parts[1]
+        flac_path = LIBRISPEECH_DIR / speaker / chapter / f'{uid}.flac'
+
+        if not flac_path.exists():
+            print(f'ERROR: FLAC not found: {flac_path}')
+            continue
+
+        trans_file = flac_path.parent / f'{speaker}-{chapter}.trans.txt'
+        transcript = ''
+        for line in trans_file.read_text().strip().split('\n'):
+            line_parts = line.split(' ', 1)
+            if line_parts[0] == uid:
+                transcript = line_parts[1]
+                break
+
+        wav_path
= AUDIO_DIR / f'sample_{i + 1:02d}.wav' + subprocess.run( + ['ffmpeg', '-y', '-i', str(flac_path), '-ar', '16000', '-ac', '1', '-sample_fmt', 's16', str(wav_path)], + capture_output=True, + check=True, + ) + + result = subprocess.run( + ['ffprobe', '-v', 'quiet', '-show_entries', 'format=duration', '-of', 'csv=p=0', str(wav_path)], + capture_output=True, + text=True, + ) + duration = float(result.stdout.strip()) + + manifest.append( + { + 'id': f'sample_{i + 1:02d}', + 'uid': uid, + 'speaker': speaker, + 'text': transcript, + 'description': pick['desc'], + 'word_count': len(transcript.split()), + 'duration_s': round(duration, 2), + 'size_kb': round(wav_path.stat().st_size / 1024, 1), + } + ) + print(f' Prepared: sample_{i + 1:02d}.wav {duration:.1f}s {len(transcript.split())}w speaker={speaker}') + + with open(AUDIO_DIR / 'manifest.json', 'w') as f: + json.dump(manifest, f, indent=2) + print(f'\n{len(manifest)} samples prepared in {AUDIO_DIR}') + return manifest + + +def load_manifest() -> List[dict]: + manifest_path = AUDIO_DIR / 'manifest.json' + if not manifest_path.exists(): + print('Samples not prepared yet. Running preparation...') + return prepare_samples() + with open(manifest_path) as f: + return json.load(f) + + +def run_deepgram(audio_bytes: bytes) -> Tuple[str, float, int]: + t0 = time.monotonic() + result = deepgram_prerecorded_from_bytes(audio_bytes, sample_rate=16000, diarize=True) + elapsed = time.monotonic() - t0 + text = ' '.join(w.get('text', '') or w.get('word', '') for w in result).strip() + return text, elapsed, len(result) + + +def run_modulate(audio_bytes: bytes) -> Tuple[str, float, int]: + t0 = time.monotonic() + result = modulate_prerecorded_from_bytes(audio_bytes, sample_rate=16000, diarize=True) + elapsed = time.monotonic() - t0 + text = ' '.join(w.get('text', '') for w in result).strip() + return text, elapsed, len(result) + + +def main(): + if '--prepare' in sys.argv: + prepare_samples() + return + + RESULTS_DIR.mkdir(parents=True, exist_ok=True) + + dg_key = os.getenv('DEEPGRAM_API_KEY') + mod_key = os.getenv('MODULATE_API_KEY') + if not dg_key: + print('ERROR: DEEPGRAM_API_KEY not set') + sys.exit(1) + if not mod_key: + print('ERROR: MODULATE_API_KEY not set') + sys.exit(1) + + manifest = load_manifest() + print(f'\nBenchmark Suite 02 — Pre-recorded ({len(manifest)} samples, real human speech)') + print(f'Source: LibriSpeech test-clean (CC BY 4.0)\n') + + results = [] + for case in manifest: + wav_path = AUDIO_DIR / f"{case['id']}.wav" + audio_bytes = wav_path.read_bytes() + ref_norm = normalize_for_wer(case['text']) + + row = { + 'id': case['id'], + 'uid': case['uid'], + 'description': case['description'], + 'speaker': case['speaker'], + 'ref_words': case['word_count'], + 'duration_s': case['duration_s'], + 'audio_kb': case['size_kb'], + 'ref_text': case['text'], + } + + print(f" [{case['id']}] {case['description']} (speaker {case['speaker']})") + + try: + dg_text, dg_time, dg_segments = run_deepgram(audio_bytes) + dg_wer = compute_wer(ref_norm, normalize_for_wer(dg_text)) if dg_text else 1.0 + dg_punct = count_punctuation(dg_text) if dg_text else {'total': 0, 'detail': {}} + row.update( + { + 'dg_time': dg_time, + 'dg_words': len(dg_text.split()) if dg_text else 0, + 'dg_wer': dg_wer, + 'dg_text': dg_text, + 'dg_segments': dg_segments, + 'dg_punct': dg_punct['total'], + 'dg_punct_detail': dg_punct['detail'], + } + ) + print( + f" Deepgram: {dg_time:.2f}s WER={dg_wer:.2%} words={len(dg_text.split())} punct={dg_punct['total']}" + ) + except Exception as e: 
+ print(f" Deepgram: ERROR - {e}") + row.update( + { + 'dg_time': -1, + 'dg_words': 0, + 'dg_wer': 1.0, + 'dg_text': f'ERROR: {e}', + 'dg_segments': 0, + 'dg_punct': 0, + 'dg_punct_detail': {}, + } + ) + + try: + mod_text, mod_time, mod_segments = run_modulate(audio_bytes) + mod_wer = compute_wer(ref_norm, normalize_for_wer(mod_text)) if mod_text else 1.0 + mod_punct = count_punctuation(mod_text) if mod_text else {'total': 0, 'detail': {}} + row.update( + { + 'mod_time': mod_time, + 'mod_words': len(mod_text.split()) if mod_text else 0, + 'mod_wer': mod_wer, + 'mod_text': mod_text, + 'mod_segments': mod_segments, + 'mod_punct': mod_punct['total'], + 'mod_punct_detail': mod_punct['detail'], + } + ) + print( + f" Modulate: {mod_time:.2f}s WER={mod_wer:.2%} words={len(mod_text.split())} punct={mod_punct['total']}" + ) + except Exception as e: + print(f" Modulate: ERROR - {e}") + row.update( + { + 'mod_time': -1, + 'mod_words': 0, + 'mod_wer': 1.0, + 'mod_text': f'ERROR: {e}', + 'mod_segments': 0, + 'mod_punct': 0, + 'mod_punct_detail': {}, + } + ) + + results.append(row) + + print('\n' + '=' * 110) + print('SUITE 02 — PRE-RECORDED BENCHMARK RESULTS (Real Human Speech — LibriSpeech test-clean)') + print('=' * 110) + + table_data = [] + for r in results: + table_data.append( + [ + r['id'], + r['ref_words'], + f"{r['duration_s']:.1f}s", + f"{r.get('dg_time', -1):.2f}s" if r.get('dg_time', -1) >= 0 else 'ERR', + f"{r.get('dg_wer', 1):.1%}", + r.get('dg_words', 0), + r.get('dg_punct', 0), + f"{r.get('mod_time', -1):.2f}s" if r.get('mod_time', -1) >= 0 else 'ERR', + f"{r.get('mod_wer', 1):.1%}", + r.get('mod_words', 0), + r.get('mod_punct', 0), + ] + ) + + print( + tabulate( + table_data, + headers=[ + 'Case', + 'Ref Words', + 'Duration', + 'DG Time', + 'DG WER', + 'DG Words', + 'DG Punct', + 'Mod Time', + 'Mod WER', + 'Mod Words', + 'Mod Punct', + ], + tablefmt='grid', + ) + ) + + valid_dg = [r for r in results if r.get('dg_time', -1) >= 0] + valid_mod = [r for r in results if r.get('mod_time', -1) >= 0] + + print('\nSUMMARY (WER computed after stripping punctuation):') + if valid_dg: + avg_dg_time = sum(r['dg_time'] for r in valid_dg) / len(valid_dg) + avg_dg_wer = sum(r['dg_wer'] for r in valid_dg) / len(valid_dg) + avg_dg_punct = sum(r.get('dg_punct', 0) for r in valid_dg) / len(valid_dg) + print( + f" Deepgram: avg_latency={avg_dg_time:.2f}s avg_WER={avg_dg_wer:.1%} " + f"avg_punct={avg_dg_punct:.1f} cases={len(valid_dg)}" + ) + if valid_mod: + avg_mod_time = sum(r['mod_time'] for r in valid_mod) / len(valid_mod) + avg_mod_wer = sum(r['mod_wer'] for r in valid_mod) / len(valid_mod) + avg_mod_punct = sum(r.get('mod_punct', 0) for r in valid_mod) / len(valid_mod) + print( + f" Modulate: avg_latency={avg_mod_time:.2f}s avg_WER={avg_mod_wer:.1%} " + f"avg_punct={avg_mod_punct:.1f} cases={len(valid_mod)}" + ) + + print('\nTRANSCRIPT COMPARISON:') + for r in results: + print(f"\n [{r['id']}] {r['description']}") + print(f" REF: {r['ref_text']}") + if r.get('dg_text', '').startswith('ERROR'): + print(f" DEEPGRAM: {r.get('dg_text', 'N/A')}") + else: + print( + f" DEEPGRAM: {r.get('dg_text', 'N/A')} (WER={r.get('dg_wer', 1):.1%}, punct={r.get('dg_punct', 0)})" + ) + if r.get('mod_text', '').startswith('ERROR'): + print(f" MODULATE: {r.get('mod_text', 'N/A')}") + else: + print( + f" MODULATE: {r.get('mod_text', 'N/A')} (WER={r.get('mod_wer', 1):.1%}, punct={r.get('mod_punct', 0)})" + ) + + output_path = RESULTS_DIR / 'suite02_prerecorded_benchmark.json' + with open(output_path, 'w') as f: + 
json.dump(results, f, indent=2) + print(f'\nDetailed results saved to: {output_path}') + + +if __name__ == '__main__': + main() diff --git a/backend/scripts/stt/o_benchmark_02_streaming.py b/backend/scripts/stt/o_benchmark_02_streaming.py new file mode 100644 index 0000000000..2e980f24ce --- /dev/null +++ b/backend/scripts/stt/o_benchmark_02_streaming.py @@ -0,0 +1,397 @@ +""" +Benchmark Suite 02: Deepgram vs Modulate — Streaming transcription. + +Uses real human speech from LibriSpeech test-clean dataset (12 samples, +12 speakers, 2-27s duration, 4-62 words). Ground truth transcripts for +accurate WER measurement. Streams at real-time pace (3200 bytes/100ms). + +Setup: + 1. Download LibriSpeech test-clean: + curl -L -o /tmp/test-clean.tar.gz https://www.openslr.org/resources/12/test-clean.tar.gz + 2. Prepare samples (shared with pre-recorded benchmark): + python scripts/stt/n_benchmark_02_prerecorded.py --prepare + +Usage: + cd backend && python scripts/stt/o_benchmark_02_streaming.py +""" + +import asyncio +import json +import os +import re +import sys +import time +from pathlib import Path +from typing import List + +from dotenv import load_dotenv + +load_dotenv(Path(__file__).resolve().parents[2] / '.env') + +from jiwer import wer as compute_wer +from tabulate import tabulate + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from utils.stt.streaming import process_audio_dg, process_audio_modulate + +PUNCT_RE = re.compile(r'[^\w\s]', re.UNICODE) + + +def normalize_for_wer(text: str) -> str: + return PUNCT_RE.sub('', text).lower().strip() + + +def count_punctuation(text: str) -> dict: + marks = re.findall(r'[^\w\s]', text) + return {'total': len(marks), 'detail': dict(sorted(((m, marks.count(m)) for m in set(marks)), key=lambda x: -x[1]))} + + +AUDIO_DIR = Path('/tmp/stt_benchmark_audio_02') +RESULTS_DIR = Path('/tmp/stt_benchmark_results') + +CHUNK_SIZE = 3200 +CHUNK_INTERVAL = 0.1 + + +def load_manifest() -> List[dict]: + manifest_path = AUDIO_DIR / 'manifest.json' + if not manifest_path.exists(): + print('ERROR: Samples not prepared. 
Run first:') + print(' python scripts/stt/n_benchmark_02_prerecorded.py --prepare') + sys.exit(1) + with open(manifest_path) as f: + return json.load(f) + + +def read_pcm_from_wav(wav_path: Path) -> bytes: + data = wav_path.read_bytes() + if data[:4] == b'RIFF': + return data[44:] + return data + + +async def stream_to_deepgram(audio_pcm: bytes, language: str) -> dict: + segments_received = [] + first_segment_time = [None] + connect_start = time.monotonic() + + def stream_transcript(segments): + if first_segment_time[0] is None: + first_segment_time[0] = time.monotonic() + segments_received.extend(segments) + + try: + socket = await asyncio.wait_for( + process_audio_dg(stream_transcript, language, 16000, 1, model='nova-3'), + timeout=15, + ) + except Exception as e: + return {'error': str(e), 'connect_time': -1} + + connect_time = time.monotonic() - connect_start + stream_start = time.monotonic() + + offset = 0 + while offset < len(audio_pcm): + chunk = audio_pcm[offset : offset + CHUNK_SIZE] + socket.send(chunk) + offset += CHUNK_SIZE + await asyncio.sleep(CHUNK_INTERVAL) + + socket.finish() + await asyncio.sleep(3) + + total_time = time.monotonic() - stream_start + text = ' '.join(s.get('text', '') for s in segments_received).strip() + first_seg_latency = (first_segment_time[0] - stream_start) if first_segment_time[0] else -1 + + return { + 'connect_time': connect_time, + 'first_segment_latency': first_seg_latency, + 'total_time': total_time, + 'segments': len(segments_received), + 'text': text, + 'words': len(text.split()) if text else 0, + } + + +async def stream_to_modulate(audio_pcm: bytes, language: str) -> dict: + segments_received = [] + first_segment_time = [None] + connect_start = time.monotonic() + + def stream_transcript(segments): + if first_segment_time[0] is None: + first_segment_time[0] = time.monotonic() + segments_received.extend(segments) + + try: + socket = await asyncio.wait_for( + process_audio_modulate(stream_transcript, 16000, language), + timeout=15, + ) + except Exception as e: + return {'error': str(e), 'connect_time': -1} + + connect_time = time.monotonic() - connect_start + stream_start = time.monotonic() + + offset = 0 + while offset < len(audio_pcm): + chunk = audio_pcm[offset : offset + CHUNK_SIZE] + socket.send(chunk) + offset += CHUNK_SIZE + await asyncio.sleep(CHUNK_INTERVAL) + + try: + await asyncio.wait_for(socket.drain_and_close(), timeout=30) + except (asyncio.TimeoutError, Exception): + pass + + total_time = time.monotonic() - stream_start + text = ' '.join(s.get('text', '') for s in segments_received).strip() + first_seg_latency = (first_segment_time[0] - stream_start) if first_segment_time[0] else -1 + + return { + 'connect_time': connect_time, + 'first_segment_latency': first_seg_latency, + 'total_time': total_time, + 'segments': len(segments_received), + 'text': text, + 'words': len(text.split()) if text else 0, + } + + +async def run_benchmark(): + RESULTS_DIR.mkdir(parents=True, exist_ok=True) + + dg_key = os.getenv('DEEPGRAM_API_KEY') + mod_key = os.getenv('MODULATE_API_KEY') + if not dg_key: + print('ERROR: DEEPGRAM_API_KEY not set') + sys.exit(1) + if not mod_key: + print('ERROR: MODULATE_API_KEY not set') + sys.exit(1) + + manifest = load_manifest() + print(f'\nBenchmark Suite 02 — Streaming ({len(manifest)} samples, real human speech)') + print(f'Source: LibriSpeech test-clean (CC BY 4.0)') + print(f'Streaming at real-time pace: {CHUNK_SIZE} bytes / {CHUNK_INTERVAL}s = 16kHz mono s16le\n') + + results = [] + for case in manifest: + 
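# Pacing sanity check: CHUNK_SIZE=3200 bytes every CHUNK_INTERVAL=0.1s is
+ # 32,000 bytes/s, which is exactly 16 kHz * 2 bytes/sample mono s16le, so
+ # both sockets receive audio at 1.0x speed and latency figures stay realistic.
+ 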
wav_path = AUDIO_DIR / f"{case['id']}.wav" + audio_pcm = read_pcm_from_wav(wav_path) + ref_norm = normalize_for_wer(case['text']) + lang = 'en' + + row = { + 'id': case['id'], + 'uid': case['uid'], + 'description': case['description'], + 'speaker': case['speaker'], + 'ref_words': case['word_count'], + 'duration_s': case['duration_s'], + 'audio_kb': len(audio_pcm) / 1024, + 'ref_text': case['text'], + } + + print(f" [{case['id']}] {case['description']} (speaker {case['speaker']})") + + try: + dg_result = await stream_to_deepgram(audio_pcm, lang) + if 'error' in dg_result: + raise RuntimeError(dg_result['error']) + dg_wer = compute_wer(ref_norm, normalize_for_wer(dg_result['text'])) if dg_result['text'] else 1.0 + dg_punct = count_punctuation(dg_result['text']) if dg_result['text'] else {'total': 0, 'detail': {}} + row.update( + { + 'dg_connect': dg_result['connect_time'], + 'dg_first_seg': dg_result['first_segment_latency'], + 'dg_total': dg_result['total_time'], + 'dg_segments': dg_result['segments'], + 'dg_words': dg_result['words'], + 'dg_wer': dg_wer, + 'dg_text': dg_result['text'], + 'dg_punct': dg_punct['total'], + 'dg_punct_detail': dg_punct['detail'], + } + ) + print( + f" Deepgram: connect={dg_result['connect_time']:.2f}s " + f"first_seg={dg_result['first_segment_latency']:.2f}s " + f"total={dg_result['total_time']:.2f}s " + f"segs={dg_result['segments']} WER={dg_wer:.2%} punct={dg_punct['total']}" + ) + except Exception as e: + print(f" Deepgram: ERROR - {e}") + row.update( + { + 'dg_connect': -1, + 'dg_first_seg': -1, + 'dg_total': -1, + 'dg_segments': 0, + 'dg_words': 0, + 'dg_wer': 1.0, + 'dg_text': f'ERROR: {e}', + 'dg_punct': 0, + 'dg_punct_detail': {}, + } + ) + + try: + mod_result = await stream_to_modulate(audio_pcm, lang) + if 'error' in mod_result: + raise RuntimeError(mod_result['error']) + mod_wer = compute_wer(ref_norm, normalize_for_wer(mod_result['text'])) if mod_result['text'] else 1.0 + mod_punct = count_punctuation(mod_result['text']) if mod_result['text'] else {'total': 0, 'detail': {}} + row.update( + { + 'mod_connect': mod_result['connect_time'], + 'mod_first_seg': mod_result['first_segment_latency'], + 'mod_total': mod_result['total_time'], + 'mod_segments': mod_result['segments'], + 'mod_words': mod_result['words'], + 'mod_wer': mod_wer, + 'mod_text': mod_result['text'], + 'mod_punct': mod_punct['total'], + 'mod_punct_detail': mod_punct['detail'], + } + ) + print( + f" Modulate: connect={mod_result['connect_time']:.2f}s " + f"first_seg={mod_result['first_segment_latency']:.2f}s " + f"total={mod_result['total_time']:.2f}s " + f"segs={mod_result['segments']} WER={mod_wer:.2%} punct={mod_punct['total']}" + ) + except Exception as e: + print(f" Modulate: ERROR - {e}") + row.update( + { + 'mod_connect': -1, + 'mod_first_seg': -1, + 'mod_total': -1, + 'mod_segments': 0, + 'mod_words': 0, + 'mod_wer': 1.0, + 'mod_text': f'ERROR: {e}', + 'mod_punct': 0, + 'mod_punct_detail': {}, + } + ) + + results.append(row) + + print('\n' + '=' * 130) + print('SUITE 02 — STREAMING BENCHMARK RESULTS (Real Human Speech — LibriSpeech test-clean)') + print('=' * 130) + + def fmt_time(v): + return f"{v:.2f}s" if v >= 0 else 'ERR' + + table_data = [] + for r in results: + table_data.append( + [ + r['id'], + r['ref_words'], + f"{r['duration_s']:.1f}s", + fmt_time(r.get('dg_connect', -1)), + fmt_time(r.get('dg_first_seg', -1)), + fmt_time(r.get('dg_total', -1)), + r.get('dg_segments', 0), + f"{r.get('dg_wer', 1):.0%}", + r.get('dg_punct', 0), + fmt_time(r.get('mod_connect', -1)), + 
fmt_time(r.get('mod_first_seg', -1)), + fmt_time(r.get('mod_total', -1)), + r.get('mod_segments', 0), + f"{r.get('mod_wer', 1):.0%}", + r.get('mod_punct', 0), + ] + ) + + print( + tabulate( + table_data, + headers=[ + 'Case', + 'Words', + 'Duration', + 'DG Conn', + 'DG 1st Seg', + 'DG Total', + 'DG Segs', + 'DG WER', + 'DG Punct', + 'Mod Conn', + 'Mod 1st Seg', + 'Mod Total', + 'Mod Segs', + 'Mod WER', + 'Mod Punct', + ], + tablefmt='grid', + ) + ) + + valid_dg = [r for r in results if r.get('dg_total', -1) >= 0] + valid_mod = [r for r in results if r.get('mod_total', -1) >= 0] + + print('\nSUMMARY (WER computed after stripping punctuation):') + if valid_dg: + dg_first_segs = [r['dg_first_seg'] for r in valid_dg if r['dg_first_seg'] >= 0] + avg_dg_punct = sum(r.get('dg_punct', 0) for r in valid_dg) / len(valid_dg) + print( + f" Deepgram: " + f"avg_connect={sum(r['dg_connect'] for r in valid_dg) / len(valid_dg):.2f}s " + f"avg_first_seg={sum(dg_first_segs) / max(1, len(dg_first_segs)):.2f}s " + f"avg_total={sum(r['dg_total'] for r in valid_dg) / len(valid_dg):.2f}s " + f"avg_WER={sum(r['dg_wer'] for r in valid_dg) / len(valid_dg):.1%} " + f"avg_punct={avg_dg_punct:.1f} " + f"cases={len(valid_dg)}" + ) + if valid_mod: + mod_first_segs = [r['mod_first_seg'] for r in valid_mod if r['mod_first_seg'] >= 0] + avg_mod_punct = sum(r.get('mod_punct', 0) for r in valid_mod) / len(valid_mod) + print( + f" Modulate: " + f"avg_connect={sum(r['mod_connect'] for r in valid_mod) / len(valid_mod):.2f}s " + f"avg_first_seg={sum(mod_first_segs) / max(1, len(mod_first_segs)):.2f}s " + f"avg_total={sum(r['mod_total'] for r in valid_mod) / len(valid_mod):.2f}s " + f"avg_WER={sum(r['mod_wer'] for r in valid_mod) / len(valid_mod):.1%} " + f"avg_punct={avg_mod_punct:.1f} " + f"cases={len(valid_mod)}" + ) + + print('\nTRANSCRIPT COMPARISON:') + for r in results: + print(f"\n [{r['id']}] {r['description']}") + print(f" REF: {r.get('ref_text', 'N/A')}") + if r.get('dg_text', '').startswith('ERROR'): + print(f" DEEPGRAM: {r.get('dg_text', 'N/A')}") + else: + print( + f" DEEPGRAM: {r.get('dg_text', 'N/A')} (WER={r.get('dg_wer', 1):.1%}, punct={r.get('dg_punct', 0)})" + ) + if r.get('mod_text', '').startswith('ERROR'): + print(f" MODULATE: {r.get('mod_text', 'N/A')}") + else: + print( + f" MODULATE: {r.get('mod_text', 'N/A')} (WER={r.get('mod_wer', 1):.1%}, punct={r.get('mod_punct', 0)})" + ) + + output_path = RESULTS_DIR / 'suite02_streaming_benchmark.json' + with open(output_path, 'w') as f: + json.dump(results, f, indent=2) + print(f'\nDetailed results saved to: {output_path}') + + +def main(): + asyncio.run(run_benchmark()) + + +if __name__ == '__main__': + main() diff --git a/backend/scripts/stt/p_listen_api_walkthrough.py b/backend/scripts/stt/p_listen_api_walkthrough.py new file mode 100644 index 0000000000..49552e4c0e --- /dev/null +++ b/backend/scripts/stt/p_listen_api_walkthrough.py @@ -0,0 +1,658 @@ +""" +Listen API Walkthrough — L2 Integration Test for /v4/listen WebSocket. + +Streams 5+ minutes of real LibriSpeech audio through the local backend's +/v4/listen WebSocket endpoint, testing both Deepgram and Modulate STT providers. +Captures transcription results, service logs, timing, and identifies flaws. + +Prerequisites: + 1. LibriSpeech test-clean extracted: + curl -L -o /tmp/test-clean.tar.gz https://www.openslr.org/resources/12/test-clean.tar.gz + cd /tmp && mkdir -p librispeech && tar xzf test-clean.tar.gz -C librispeech + 2. Backend running via beast omi dev: + beast omi dev start backend + 3. 
Environment: LOCAL_DEVELOPMENT=true (beast omi dev default) + +Usage: + cd backend && python3 scripts/stt/p_listen_api_walkthrough.py + cd backend && python3 scripts/stt/p_listen_api_walkthrough.py --provider deepgram + cd backend && python3 scripts/stt/p_listen_api_walkthrough.py --provider modulate +""" + +import argparse +import asyncio +import json +import re +import socket +import subprocess +import sys +import time +from datetime import datetime +from pathlib import Path +from typing import List + +import websockets + +BACKEND_HOST = "localhost" +BACKEND_PORT = 8700 +LISTEN_URL = f"ws://{BACKEND_HOST}:{BACKEND_PORT}/v4/listen" +DEV_AUTH_HEADER = {"authorization": "Bearer dev-token"} + +LIBRISPEECH_DIR = Path('/tmp/librispeech/LibriSpeech/test-clean') +RESULTS_DIR = Path('/tmp/stt_listen_walkthrough') + +CHUNK_SIZE = 3200 +CHUNK_INTERVAL_S = 0.1 +TARGET_DURATION_S = 300 + +PUNCT_RE = re.compile(r'[^\w\s]', re.UNICODE) + + +def normalize_for_wer(text: str) -> str: + return PUNCT_RE.sub('', text).lower().strip() + + +def build_audio_playlist(target_seconds: float = TARGET_DURATION_S) -> List[dict]: + if not LIBRISPEECH_DIR.exists(): + print(f'ERROR: LibriSpeech not found at {LIBRISPEECH_DIR}') + print(' curl -L -o /tmp/test-clean.tar.gz https://www.openslr.org/resources/12/test-clean.tar.gz') + print(' cd /tmp && mkdir -p librispeech && tar xzf test-clean.tar.gz -C librispeech') + sys.exit(1) + + playlist = [] + total = 0.0 + speakers = sorted(LIBRISPEECH_DIR.iterdir()) + + for speaker_dir in speakers: + if total >= target_seconds: + break + if not speaker_dir.is_dir(): + continue + for chapter_dir in sorted(speaker_dir.iterdir()): + if total >= target_seconds: + break + if not chapter_dir.is_dir(): + continue + + trans_files = list(chapter_dir.glob('*.trans.txt')) + transcripts = {} + for tf in trans_files: + for line in tf.read_text().strip().split('\n'): + parts = line.split(' ', 1) + if len(parts) == 2: + transcripts[parts[0]] = parts[1] + + for flac in sorted(chapter_dir.glob('*.flac')): + if total >= target_seconds: + break + uid = flac.stem + text = transcripts.get(uid, '') + r = subprocess.run( + ['ffprobe', '-v', 'quiet', '-show_entries', 'format=duration', '-of', 'csv=p=0', str(flac)], + capture_output=True, + text=True, + ) + dur = float(r.stdout.strip()) if r.stdout.strip() else 0 + if dur < 1: + continue + playlist.append( + { + 'uid': uid, + 'flac': str(flac), + 'speaker': speaker_dir.name, + 'text': text, + 'duration_s': round(dur, 2), + 'word_count': len(text.split()) if text else 0, + } + ) + total += dur + + return playlist + + +def convert_to_pcm16(flac_path: str) -> bytes: + r = subprocess.run( + ['ffmpeg', '-y', '-i', flac_path, '-f', 's16le', '-ar', '16000', '-ac', '1', 'pipe:1'], + capture_output=True, + ) + return r.stdout + + +def capture_service_logs(service: str, output_path: Path, duration: int = 5): + try: + r = subprocess.run( + ['timeout', str(duration), 'beast', 'omi', 'dev', 'logs', service], + capture_output=True, + text=True, + timeout=duration + 5, + ) + output_path.write_text(r.stdout + r.stderr) + except Exception: + pass + + +async def run_listen_test( + provider: str, + playlist: List[dict], + stt_service_models: str, +) -> dict: + results = { + 'provider': provider, + 'stt_service_models': stt_service_models, + 'start_time': datetime.now(tz=None).isoformat(), + 'samples': [], + 'events': [], + 'flaws': [], + 'stats': {}, + } + + total_audio_s = sum(s['duration_s'] for s in playlist) + total_words = sum(s['word_count'] for s in playlist) + 
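# Rough wall-time expectation, assuming the real-time pacing below holds:
+ # s16le mono at 16 kHz is 32,000 bytes/s, plus 1s of injected silence per
+ # utterance, so expected_stream_s ~= total_audio_s + len(playlist) * 1.0
+ 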
print(f'\n{"=" * 80}') + print(f'LISTEN API WALKTHROUGH — {provider.upper()}') + print(f'{"=" * 80}') + print( + f'Audio: {total_audio_s:.1f}s ({total_audio_s / 60:.1f} min), {len(playlist)} utterances, {total_words} words' + ) + print(f'STT_SERVICE_MODELS={stt_service_models}') + print(f'Endpoint: {LISTEN_URL}') + print() + + params = { + 'language': 'en', + 'sample_rate': '16000', + 'codec': 'pcm16', + 'channels': '1', + 'include_speech_profile': 'false', + 'conversation_timeout': '30', + } + url = f'{LISTEN_URL}?{"&".join(f"{k}={v}" for k, v in params.items())}' + + segments_received = [] + segments_by_id = {} + events_received = [] + first_segment_time = [None] + ready_time = [None] + connect_start = time.monotonic() + + try: + ws = await asyncio.wait_for( + websockets.connect( + url, + additional_headers=DEV_AUTH_HEADER, + ping_timeout=None, + ping_interval=None, + max_size=None, + close_timeout=10, + ), + timeout=15, + ) + except Exception as e: + err = f'Connection failed: {e}' + print(f' ERROR: {err}') + results['flaws'].append({'type': 'connection_failure', 'detail': err}) + results['end_time'] = datetime.now(tz=None).isoformat() + return results + + connect_time = time.monotonic() - connect_start + print(f' Connected in {connect_time:.2f}s') + + msg_count = [0] + + async def recv_messages(): + nonlocal segments_received, events_received + try: + async for raw in ws: + msg_count[0] += 1 + if isinstance(raw, bytes): + ts = time.monotonic() - connect_start + if msg_count[0] <= 5: + print(f' [{ts:.1f}s] binary msg #{msg_count[0]}: {len(raw)} bytes') + continue + if raw == 'ping': + continue + + ts = time.monotonic() - connect_start + + try: + msg = json.loads(raw) + except json.JSONDecodeError: + print(f' [{ts:.1f}s] non-JSON msg: {raw[:200]}') + continue + + try: + # Server sends segments as bare JSON array [...] 
or as {"segments": [...]} + if isinstance(msg, list): + segs = msg + elif isinstance(msg, dict): + msg['_recv_ts'] = round(ts, 3) + + if msg.get('status') == 'ready': + ready_time[0] = ts + events_received.append({'type': 'ready', 'ts': round(ts, 3)}) + print(f' [{ts:.1f}s] Server ready') + continue + + if 'segments' in msg: + segs = msg['segments'] + else: + msg_type = str(msg.get('type', msg.get('status', 'unknown'))) + if msg_count[0] <= 30 or msg_type not in ('ping', 'pong'): + print(f' [{ts:.1f}s] event: {msg_type} — {json.dumps(msg)[:300]}') + events_received.append(msg) + continue + else: + continue + + if not isinstance(segs, list): + continue + seg_texts = [] + for seg in segs: + if not isinstance(seg, dict): + continue + text = seg.get('text', '').strip() + if text: + if first_segment_time[0] is None: + first_segment_time[0] = ts + entry = { + 'id': seg.get('id', ''), + 'text': text, + 'speaker': seg.get('speaker', ''), + 'start': seg.get('start', 0), + 'end': seg.get('end', 0), + 'recv_ts': round(ts, 3), + 'is_user': seg.get('is_user', False), + } + segments_received.append(entry) + seg_id = seg.get('id', '') + if seg_id: + segments_by_id[seg_id] = entry + seg_texts.append(text[:60]) + if seg_texts: + print(f' [{ts:.1f}s] SEGMENT: {" | ".join(seg_texts)}') + + except Exception as e: + print(f' [{ts:.1f}s] MSG PARSE ERROR: {e} — raw: {raw[:300]}') + continue + + except websockets.exceptions.ConnectionClosed as e: + events_received.append({'type': 'ws_closed', 'code': e.code, 'reason': str(e.reason)}) + print(f' WS CLOSED: code={e.code} reason={e.reason}') + except Exception as e: + events_received.append({'type': 'recv_error', 'detail': str(e)}) + print(f' RECV ERROR: {e}') + + recv_task = asyncio.create_task(recv_messages()) + + # Wait for ready + deadline = time.monotonic() + 30 + while ready_time[0] is None and time.monotonic() < deadline: + if recv_task.done(): + break + await asyncio.sleep(0.1) + + if ready_time[0] is None: + results['flaws'].append({'type': 'no_ready_signal', 'detail': 'Server did not send ready status within 30s'}) + print(' WARNING: No ready signal received, proceeding anyway...') + + # Stream audio + stream_start = time.monotonic() + total_bytes_sent = 0 + samples_sent = 0 + + send_failed = False + for i, sample in enumerate(playlist): + if send_failed: + break + pcm = convert_to_pcm16(sample['flac']) + if not pcm: + continue + + offset = 0 + while offset < len(pcm): + chunk = pcm[offset : offset + CHUNK_SIZE] + try: + await ws.send(chunk) + except Exception as e: + err = f'Send failed at sample {i} offset {offset}: {e}' + results['flaws'].append({'type': 'send_failure', 'detail': err}) + print(f' ERROR: {err}') + send_failed = True + break + offset += CHUNK_SIZE + total_bytes_sent += len(chunk) + await asyncio.sleep(CHUNK_INTERVAL_S) + + if send_failed: + break + + samples_sent += 1 + elapsed = time.monotonic() - connect_start + words_so_far = sum(len(s['text'].split()) for s in segments_received) + + if (i + 1) % 10 == 0 or i == len(playlist) - 1: + print( + f' [{elapsed:.0f}s] Sent {samples_sent}/{len(playlist)} samples, ' + f'{total_bytes_sent / 1024:.0f}KB, received {len(segments_received)} segments ({words_so_far} words)' + ) + + # Insert 1s silence between utterances + silence = b'\x00' * (16000 * 2) + try: + await ws.send(silence) + except Exception: + send_failed = True + break + total_bytes_sent += len(silence) + await asyncio.sleep(1.0) + + stream_duration = time.monotonic() - stream_start + + # Wait for trailing transcription — Modulate 
without partial_results + # batches utterances and can take 30-45s after end-of-audio + print(f' Waiting for trailing transcription results...') + await asyncio.sleep(50) + + # Close + try: + await ws.close() + except Exception: + pass + + await asyncio.sleep(2) + recv_task.cancel() + try: + await recv_task + except (asyncio.CancelledError, Exception): + pass + + total_time = time.monotonic() - connect_start + + # Compute stats — use deduplicated segments (final version of each ID) + final_segments = list(segments_by_id.values()) if segments_by_id else segments_received + all_received_text = ' '.join(s['text'] for s in final_segments) + all_ref_text = ' '.join(s['text'] for s in playlist[:samples_sent]) + received_words = len(all_received_text.split()) if all_received_text else 0 + + wer = None + try: + from jiwer import wer as compute_wer + + ref_norm = normalize_for_wer(all_ref_text) + hyp_norm = normalize_for_wer(all_received_text) + if ref_norm and hyp_norm: + wer = compute_wer(ref_norm, hyp_norm) + except ImportError: + pass + + punct_marks = re.findall(r'[^\w\s]', all_received_text) + + stats = { + 'connect_time_s': round(connect_time, 3), + 'ready_time_s': round(ready_time[0], 3) if ready_time[0] else None, + 'first_segment_time_s': round(first_segment_time[0], 3) if first_segment_time[0] else None, + 'stream_duration_s': round(stream_duration, 1), + 'total_time_s': round(total_time, 1), + 'samples_sent': samples_sent, + 'total_audio_s': round(total_audio_s, 1), + 'total_bytes_sent': total_bytes_sent, + 'segment_updates': len(segments_received), + 'segments_final': len(final_segments), + 'words_received': received_words, + 'ref_words': total_words, + 'wer': round(wer, 4) if wer is not None else None, + 'punctuation_marks': len(punct_marks), + 'unique_speakers': len(set(s['speaker'] for s in final_segments)), + 'events_count': len(events_received), + } + results['stats'] = stats + results['events'] = events_received[:100] + results['end_time'] = datetime.now(tz=None).isoformat() + results['final_segments'] = final_segments + results['full_transcript'] = all_received_text + results['full_reference'] = all_ref_text + + # Flaw detection + if not final_segments: + results['flaws'].append({'type': 'no_transcription', 'detail': 'No segments received from server'}) + + if ready_time[0] and first_segment_time[0]: + latency = first_segment_time[0] - ready_time[0] + if latency > 15: + results['flaws'].append( + {'type': 'high_first_segment_latency', 'detail': f'First segment took {latency:.1f}s after ready'} + ) + + if received_words > 0 and total_words > 0: + word_ratio = received_words / total_words + if word_ratio < 0.3: + results['flaws'].append( + { + 'type': 'low_word_capture', + 'detail': f'Only captured {word_ratio:.0%} of expected words ({received_words}/{total_words})', + } + ) + + if wer is not None and wer > 0.5: + results['flaws'].append({'type': 'high_wer', 'detail': f'WER={wer:.1%} — more than 50% of words incorrect'}) + + if segments_received: + last_seg_ts = segments_received[-1]['recv_ts'] + if total_time - last_seg_ts > 30: + results['flaws'].append( + { + 'type': 'stale_transcription', + 'detail': f'Last segment at {last_seg_ts:.0f}s but test ran {total_time:.0f}s — {total_time - last_seg_ts:.0f}s gap', + } + ) + + # Print summary + print(f'\n RESULTS — {provider.upper()}') + print(f' {"─" * 60}') + print(f' Connect time: {stats["connect_time_s"]:.2f}s') + print(f' Ready time: {stats["ready_time_s"]:.2f}s' if stats['ready_time_s'] else ' Ready time: N/A') + print( + f' 
First segment: {stats["first_segment_time_s"]:.2f}s' + if stats['first_segment_time_s'] + else ' First segment: N/A' + ) + print(f' Stream duration: {stats["stream_duration_s"]:.1f}s') + print(f' Total time: {stats["total_time_s"]:.1f}s') + print(f' Samples sent: {stats["samples_sent"]}') + print(f' Audio duration: {stats["total_audio_s"]:.1f}s ({stats["total_audio_s"] / 60:.1f} min)') + print(f' Bytes sent: {stats["total_bytes_sent"] / 1024:.0f} KB') + print(f' Segment updates: {stats["segment_updates"]}') + print(f' Final segments: {stats["segments_final"]}') + print(f' Words received: {stats["words_received"]} (ref: {stats["ref_words"]})') + if stats['wer'] is not None: + print(f' WER: {stats["wer"]:.1%}') + print(f' Punctuation: {stats["punctuation_marks"]}') + print(f' Unique speakers: {stats["unique_speakers"]}') + + if results['flaws']: + print(f'\n FLAWS DETECTED:') + for flaw in results['flaws']: + print(f' [{flaw["type"]}] {flaw["detail"]}') + else: + print(f'\n No flaws detected.') + + # Save per-sample transcript comparison + results['transcript_sample'] = [] + seg_idx = 0 + for sample in playlist[: min(samples_sent, 20)]: + matching = [] + while seg_idx < len(segments_received): + s = segments_received[seg_idx] + matching.append(s['text']) + seg_idx += 1 + if len(' '.join(matching).split()) >= sample['word_count']: + break + results['transcript_sample'].append( + { + 'uid': sample['uid'], + 'ref': sample['text'], + 'hyp': ' '.join(matching), + 'ref_words': sample['word_count'], + } + ) + + return results + + +async def main(): + global BACKEND_PORT, LISTEN_URL + + parser = argparse.ArgumentParser(description='Listen API Walkthrough — L2 Integration Test') + parser.add_argument('--provider', choices=['deepgram', 'modulate', 'both'], default='both') + parser.add_argument('--duration', type=int, default=TARGET_DURATION_S, help='Target audio duration in seconds') + parser.add_argument('--port', type=int, default=BACKEND_PORT, help='Backend port') + parser.add_argument( + '--skip-restart', action='store_true', help='Skip backend restart, use currently running backend' + ) + args = parser.parse_args() + + BACKEND_PORT = args.port + LISTEN_URL = f'ws://{BACKEND_HOST}:{BACKEND_PORT}/v4/listen' + + RESULTS_DIR.mkdir(parents=True, exist_ok=True) + + try: + sock = socket.create_connection((BACKEND_HOST, BACKEND_PORT), timeout=3) + sock.close() + except (socket.timeout, ConnectionRefusedError, OSError): + print(f'ERROR: Backend not running on {BACKEND_HOST}:{BACKEND_PORT}') + print(' Run: beast omi dev start backend') + sys.exit(1) + + print(f'Building audio playlist (target: {args.duration}s / {args.duration / 60:.1f} min)...') + playlist = build_audio_playlist(args.duration) + total_s = sum(s['duration_s'] for s in playlist) + print(f' {len(playlist)} utterances, {total_s:.1f}s ({total_s / 60:.1f} min)') + + providers = [] + if args.provider in ('deepgram', 'both'): + providers.append(('deepgram', 'dg-nova-3')) + if args.provider in ('modulate', 'both'): + providers.append(('modulate', 'modulate-velma-2')) + + all_results = [] + + for provider_name, stt_models in providers: + print(f'\n--- Configuring backend for {provider_name} (STT_SERVICE_MODELS={stt_models}) ---') + + if args.skip_restart: + print(f' --skip-restart: using currently running backend on port {BACKEND_PORT}') + else: + subprocess.run(['beast', 'omi', 'dev', 'stop', 'backend'], capture_output=True) + await asyncio.sleep(2) + + env_override = f'STT_SERVICE_MODELS={stt_models}' + print(f' Starting backend with 
{env_override}...') + + subprocess.Popen( + ['bash', '-c', f'{env_override} beast omi dev start backend'], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ).wait(timeout=30) + + for attempt in range(30): + try: + sock = socket.create_connection((BACKEND_HOST, BACKEND_PORT), timeout=2) + sock.close() + break + except (socket.timeout, ConnectionRefusedError, OSError): + await asyncio.sleep(1) + else: + print(f' ERROR: Backend did not start within 30s') + continue + + await asyncio.sleep(3) + + print(f' Backend ready on port {BACKEND_PORT}') + + pre_log_path = RESULTS_DIR / f'{provider_name}_pre_logs.txt' + capture_service_logs('backend', pre_log_path, duration=3) + + result = await run_listen_test(provider_name, playlist, stt_models) + all_results.append(result) + + post_log_path = RESULTS_DIR / f'{provider_name}_post_logs.txt' + capture_service_logs('backend', post_log_path, duration=5) + + result_path = RESULTS_DIR / f'{provider_name}_result.json' + with open(result_path, 'w') as f: + json.dump(result, f, indent=2, default=str) + print(f' Results saved to: {result_path}') + + if len(all_results) == 2: + print(f'\n{"=" * 80}') + print('COMPARISON SUMMARY') + print(f'{"=" * 80}') + print(f'{"Metric":<30} {"Deepgram":<25} {"Modulate":<25}') + print(f'{"─" * 80}') + + dg, mod = all_results[0], all_results[1] + dg_s, mod_s = dg['stats'], mod['stats'] + + metrics = [ + ('Connect time', f'{dg_s["connect_time_s"]:.2f}s', f'{mod_s["connect_time_s"]:.2f}s'), + ( + 'Ready time', + f'{dg_s["ready_time_s"]:.2f}s' if dg_s['ready_time_s'] else 'N/A', + f'{mod_s["ready_time_s"]:.2f}s' if mod_s['ready_time_s'] else 'N/A', + ), + ( + 'First segment', + f'{dg_s["first_segment_time_s"]:.2f}s' if dg_s['first_segment_time_s'] else 'N/A', + f'{mod_s["first_segment_time_s"]:.2f}s' if mod_s['first_segment_time_s'] else 'N/A', + ), + ('Segment updates', str(dg_s['segment_updates']), str(mod_s['segment_updates'])), + ('Final segments', str(dg_s['segments_final']), str(mod_s['segments_final'])), + ( + 'Words received', + f'{dg_s["words_received"]} / {dg_s["ref_words"]}', + f'{mod_s["words_received"]} / {mod_s["ref_words"]}', + ), + ( + 'WER', + f'{dg_s["wer"]:.1%}' if dg_s['wer'] is not None else 'N/A', + f'{mod_s["wer"]:.1%}' if mod_s['wer'] is not None else 'N/A', + ), + ('Punctuation marks', str(dg_s['punctuation_marks']), str(mod_s['punctuation_marks'])), + ('Unique speakers', str(dg_s['unique_speakers']), str(mod_s['unique_speakers'])), + ('Flaws', str(len(dg['flaws'])), str(len(mod['flaws']))), + ] + + for label, dg_val, mod_val in metrics: + print(f'{label:<30} {dg_val:<25} {mod_val:<25}') + + print(f'\nTRANSCRIPT COMPARISON (first 5 samples):') + for i in range(min(5, len(dg.get('transcript_sample', [])))): + dg_t = dg['transcript_sample'][i] if i < len(dg.get('transcript_sample', [])) else {} + mod_t = mod['transcript_sample'][i] if i < len(mod.get('transcript_sample', [])) else {} + print(f'\n [{dg_t.get("uid", "?")}]') + print(f' REF: {dg_t.get("ref", "N/A")}') + print(f' DEEPGRAM: {dg_t.get("hyp", "N/A")}') + print(f' MODULATE: {mod_t.get("hyp", "N/A")}') + + print(f'\nALL FLAWS FOUND:') + all_flaws = [] + for r in all_results: + for f in r['flaws']: + f['provider'] = r['provider'] + all_flaws.append(f) + if all_flaws: + for f in all_flaws: + print(f' [{f["provider"]}] [{f["type"]}] {f["detail"]}') + else: + print(' None detected.') + + combined_path = RESULTS_DIR / 'listen_walkthrough_combined.json' + with open(combined_path, 'w') as f: + json.dump(all_results, f, indent=2, default=str) + 
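# Hypothetical offline inspection of the combined dump, e.g.:
+ # python3 -c "import json; d = json.load(open('/tmp/stt_listen_walkthrough/listen_walkthrough_combined.json')); print(json.dumps([r['stats'] for r in d], indent=2))"
+ 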
print(f'\nCombined results saved to: {combined_path}') + print(f'Log files in: {RESULTS_DIR}/') + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/backend/scripts/stt/q_modulate_debug.py b/backend/scripts/stt/q_modulate_debug.py new file mode 100644 index 0000000000..d0ffe74ca5 --- /dev/null +++ b/backend/scripts/stt/q_modulate_debug.py @@ -0,0 +1,417 @@ +""" +Debug: stream the same 5-min LibriSpeech audio directly to Modulate API +and log every raw message. Compare raw Modulate output vs what our +SafeModulateSocket produces to find where words get lost. + +Usage: + python3 scripts/stt/q_modulate_debug.py +""" + +import asyncio +import json +import os +import struct +import subprocess +import sys +import time +import urllib.parse +from pathlib import Path + +import websockets + +MODULATE_API_KEY = os.getenv('MODULATE_API_KEY', '') +LIBRISPEECH_DIR = Path('/tmp/librispeech/LibriSpeech/test-clean') +CHUNK_SIZE = 3200 +CHUNK_INTERVAL_S = 0.1 +TARGET_DURATION = 300 + + +def build_playlist(target_s): + playlist = [] + total_s = 0 + ref_path = LIBRISPEECH_DIR + for reader_dir in sorted(ref_path.iterdir()): + if not reader_dir.is_dir(): + continue + for chapter_dir in sorted(reader_dir.iterdir()): + if not chapter_dir.is_dir(): + continue + trans_file = list(chapter_dir.glob('*.trans.txt')) + if not trans_file: + continue + transcripts = {} + for line in trans_file[0].read_text().strip().split('\n'): + parts = line.strip().split(' ', 1) + if len(parts) == 2: + transcripts[parts[0]] = parts[1] + for flac in sorted(chapter_dir.glob('*.flac')): + uid = flac.stem + ref = transcripts.get(uid, '') + playlist.append({'flac': str(flac), 'ref': ref, 'uid': uid}) + result = subprocess.run( + [ + 'ffprobe', + '-v', + 'error', + '-show_entries', + 'format=duration', + '-of', + 'default=noprint_wrappers=1:nokey=1', + str(flac), + ], + capture_output=True, + text=True, + ) + dur = float(result.stdout.strip()) if result.returncode == 0 else 5.0 + total_s += dur + if total_s >= target_s: + return playlist, total_s + return playlist, total_s + + +def convert_to_pcm16(flac_path): + result = subprocess.run( + ['ffmpeg', '-y', '-i', flac_path, '-f', 's16le', '-ar', '16000', '-ac', '1', 'pipe:1'], + capture_output=True, + ) + return result.stdout if result.returncode == 0 else None + + +def compute_wer(ref, hyp): + ref_words = ref.upper().split() + hyp_words = hyp.upper().split() + if not ref_words: + return 0.0 if not hyp_words else 1.0 + d = [[0] * (len(hyp_words) + 1) for _ in range(len(ref_words) + 1)] + for i in range(len(ref_words) + 1): + d[i][0] = i + for j in range(len(hyp_words) + 1): + d[0][j] = j + for i in range(1, len(ref_words) + 1): + for j in range(1, len(hyp_words) + 1): + if ref_words[i - 1] == hyp_words[j - 1]: + d[i][j] = d[i - 1][j - 1] + else: + d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1]) + return d[len(ref_words)][len(hyp_words)] / len(ref_words) + + +async def main(): + playlist, total_s = build_playlist(TARGET_DURATION) + ref_text = ' '.join(s['ref'] for s in playlist) + ref_words = ref_text.split() + print(f'Audio: {total_s:.1f}s, {len(playlist)} utterances, {len(ref_words)} ref words\n') + + # --- Test 1: Direct to Modulate API (like benchmark, but one long connection) --- + print('=' * 70) + print('TEST 1: Direct to Modulate API — single long connection, all audio') + print('=' * 70) + + params = { + 'api_key': MODULATE_API_KEY, + 'speaker_diarization': 'true', + 'partial_results': 'true', + 'sample_rate': '16000', + 'audio_format': 's16le', + 
'num_channels': '1', + 'language': 'en', + } + uri = f'wss://modulate-developer-apis.com/api/velma-2-stt-streaming?{urllib.parse.urlencode(params)}' + + ws = await websockets.connect(uri, ping_timeout=30, ping_interval=10, max_size=None) + print(f'Connected to Modulate API') + + raw_messages = [] + all_utterances = [] + all_partials = [] + recv_done = asyncio.Event() + + async def recv_loop(): + try: + async for raw_msg in ws: + msg = json.loads(raw_msg) + msg['_recv_ts'] = time.monotonic() - send_start + raw_messages.append(msg) + msg_type = msg.get('type', '') + if msg_type == 'partial_utterance': + pu = msg.get('partial_utterance', msg) + all_partials.append(pu) + elif msg_type == 'utterance': + utt = msg.get('utterance', msg) + all_utterances.append(utt) + print(f' [{msg["_recv_ts"]:.1f}s] UTTERANCE: {utt.get("text", "")[:80]}') + elif msg_type == 'done': + print(f' [{msg["_recv_ts"]:.1f}s] DONE: duration_ms={msg.get("duration_ms")}') + elif msg_type == 'error': + print(f' [{msg["_recv_ts"]:.1f}s] ERROR: {msg}') + except websockets.exceptions.ConnectionClosed as e: + print(f' WS closed: {e}') + finally: + recv_done.set() + + send_start = time.monotonic() + recv_task = asyncio.create_task(recv_loop()) + + total_bytes = 0 + for i, sample in enumerate(playlist): + pcm = convert_to_pcm16(sample['flac']) + if not pcm: + continue + offset = 0 + while offset < len(pcm): + chunk = pcm[offset : offset + CHUNK_SIZE] + try: + await ws.send(chunk) + except Exception as e: + print(f' Send error at sample {i}: {e}') + break + offset += CHUNK_SIZE + total_bytes += len(chunk) + await asyncio.sleep(CHUNK_INTERVAL_S) + + if (i + 1) % 10 == 0: + elapsed = time.monotonic() - send_start + print(f' [{elapsed:.0f}s] Sent {i + 1}/{len(playlist)} samples, {total_bytes / 1024:.0f}KB') + + # Signal end of audio + try: + await ws.send('') + except Exception: + pass + + print(f'\nAll audio sent ({total_bytes / 1024:.0f}KB). 
Waiting for final results...') + try: + await asyncio.wait_for(recv_done.wait(), timeout=60) + except asyncio.TimeoutError: + print(' Timed out waiting for recv loop') + recv_task.cancel() + try: + await ws.close() + except Exception: + pass + + # Analyze raw results + print(f'\n--- RAW MODULATE RESULTS ---') + print(f'Total messages: {len(raw_messages)}') + print(f'Utterances (final): {len(all_utterances)}') + print(f'Partial utterances: {len(all_partials)}') + + # Build transcript from final utterances only + utterance_text = ' '.join(u.get('text', '') for u in all_utterances).strip() + utterance_words = utterance_text.split() + print(f'Words from utterances: {len(utterance_words)}') + + # Build transcript from partials (simulating confirmed-word delta) + partial_words_emitted = 0 + delta_words = [] + for pu in all_partials: + text = pu.get('text', '').strip() + if not text: + continue + words = text.split() + confirmed_end = len(words) - 1 + if confirmed_end <= partial_words_emitted: + continue + delta = words[partial_words_emitted:confirmed_end] + delta_words.extend(delta) + partial_words_emitted = confirmed_end + + # Also handle final utterances with the delta approach + for utt in all_utterances: + text = utt.get('text', '').strip() + if not text: + continue + words = text.split() + if partial_words_emitted > 0: + remaining = words[partial_words_emitted:] + partial_words_emitted = 0 + delta_words.extend(remaining) + else: + delta_words.extend(words) + + delta_text = ' '.join(delta_words) + print(f'Words from delta approach: {len(delta_words)}') + + # WER + import re + + def strip_punct(t): + return re.sub(r'[^\w\s]', '', t) + + ref_clean = strip_punct(ref_text) + utt_wer = compute_wer(ref_clean, strip_punct(utterance_text)) + delta_wer = compute_wer(ref_clean, strip_punct(delta_text)) + print(f'\nWER (utterances only): {utt_wer * 100:.1f}%') + print(f'WER (delta approach): {delta_wer * 100:.1f}%') + + # Show first few utterances + print(f'\n--- FIRST 5 UTTERANCES ---') + for i, u in enumerate(all_utterances[:5]): + t = u.get('text', '') + print(f' [{i}] start={u.get("start_ms", 0)}ms dur={u.get("duration_ms", 0)}ms words={len(t.split())}') + print(f' {t[:120]}') + + # Show last few utterances + print(f'\n--- LAST 5 UTTERANCES ---') + for i, u in enumerate(all_utterances[-5:]): + t = u.get('text', '') + print( + f' [{len(all_utterances) - 5 + i}] start={u.get("start_ms", 0)}ms dur={u.get("duration_ms", 0)}ms words={len(t.split())}' + ) + print(f' {t[:120]}') + + # Check partial_words_emitted drift + print(f'\n--- PARTIAL WORD COUNTER ANALYSIS ---') + max_counter = 0 + resets = 0 + prev_counter = 0 + for pu in all_partials: + words = pu.get('text', '').split() + if len(words) - 1 < prev_counter and prev_counter > 5: + resets += 1 + if len(words) > max_counter: + max_counter = len(words) + prev_counter = len(words) - 1 + + print(f'Max partial word count in a single partial: {max_counter}') + print(f'Times counter appeared to reset (new partial shorter than prev): {resets}') + + # Transcript samples + print(f'\n--- UTTERANCE TRANSCRIPT (first 300 chars) ---') + print(utterance_text[:300]) + print(f'\n--- DELTA TRANSCRIPT (first 300 chars) ---') + print(delta_text[:300]) + print(f'\n--- REFERENCE (first 300 chars) ---') + print(ref_text[:300]) + + # Save full results + # Detailed delta trace — show where words get lost/garbled + print(f'\n--- DELTA APPROACH DETAILED TRACE (first 30 events) ---') + pwe = 0 + trace_events = [] + for idx, pu in enumerate(all_partials): + text = 
pu.get('text', '').strip() + if not text: + continue + words = text.split() + confirmed_end = len(words) - 1 + old_pwe = pwe + if confirmed_end <= pwe: + if len(words) < pwe and pwe > 5: + trace_events.append( + { + 'idx': idx, + 'action': 'RESET_SKIPPED', + 'words_in_partial': len(words), + 'counter': pwe, + 'text_start': ' '.join(words[:5]), + } + ) + continue + delta = words[pwe:confirmed_end] + pwe = confirmed_end + if old_pwe > 5 and len(words) < old_pwe: + trace_events.append( + { + 'idx': idx, + 'action': 'CROSS_BOUNDARY_EMIT', + 'words_in_partial': len(words), + 'old_counter': old_pwe, + 'new_counter': pwe, + 'delta': ' '.join(delta[:10]), + 'text_start': ' '.join(words[:5]), + } + ) + elif len(trace_events) < 30: + trace_events.append( + { + 'idx': idx, + 'action': 'EMIT', + 'words_in_partial': len(words), + 'counter': pwe, + 'delta': ' '.join(delta[:8]), + } + ) + + for utt_idx, utt in enumerate(all_utterances): + text = utt.get('text', '').strip() + if not text: + continue + words = text.split() + if pwe > 0: + remaining = words[pwe:] + trace_events.append( + { + 'idx': f'UTT{utt_idx}', + 'action': 'UTT_RESET', + 'words_in_utterance': len(words), + 'old_counter': pwe, + 'remaining_words': len(remaining), + 'remaining_start': ' '.join(remaining[:5]) if remaining else '(none)', + } + ) + pwe = 0 + else: + trace_events.append( + { + 'idx': f'UTT{utt_idx}', + 'action': 'UTT_FULL', + 'words_in_utterance': len(words), + } + ) + + for ev in trace_events[:50]: + print(f' {ev}') + + # Count how many partials were skipped due to counter + pwe2 = 0 + skipped = 0 + emitted = 0 + cross_boundary = 0 + for pu in all_partials: + text = pu.get('text', '').strip() + if not text: + continue + words = text.split() + confirmed_end = len(words) - 1 + if confirmed_end <= pwe2: + skipped += 1 + continue + if pwe2 > 10 and len(words) < pwe2: + cross_boundary += 1 + pwe2 = confirmed_end + emitted += 1 + + print(f'\n--- DELTA COUNTER SUMMARY ---') + print(f'Partials processed: {len(all_partials)}') + print(f'Emitted: {emitted}') + print(f'Skipped (counter too high): {skipped}') + print(f'Cross-boundary emits (wrong offset): {cross_boundary}') + + out = { + 'raw_message_count': len(raw_messages), + 'utterance_count': len(all_utterances), + 'partial_count': len(all_partials), + 'utterance_wer': utt_wer, + 'delta_wer': delta_wer, + 'utterance_word_count': len(utterance_words), + 'delta_word_count': len(delta_words), + 'ref_word_count': len(ref_words), + 'utterance_text': utterance_text, + 'delta_text': delta_text, + 'ref_text': ref_text, + 'utterances': all_utterances, + 'partial_count_by_utterance': max_counter, + 'counter_resets': resets, + 'partials': [ + {'text': p.get('text', ''), 'start_ms': p.get('start_ms', 0), 'type': 'partial'} for p in all_partials + ], + } + out_path = '/tmp/modulate_debug_direct.json' + with open(out_path, 'w') as f: + json.dump(out, f, indent=2) + print(f'\nFull results saved to {out_path}') + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/backend/scripts/stt/r_ab_modulate_compare.py b/backend/scripts/stt/r_ab_modulate_compare.py new file mode 100644 index 0000000000..693d7d38f0 --- /dev/null +++ b/backend/scripts/stt/r_ab_modulate_compare.py @@ -0,0 +1,532 @@ +""" +A/B comparison: Direct Modulate API vs Backend /v4/listen (Modulate STT). + +Same audio, same pacing, same silence gaps. Compares WER to find implementation gaps. 
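+Both tests reuse the same PCM conversion, chunk pacing, and 1s silence
+gaps, so a meaningful WER delta should implicate the backend's segment
+handling rather than audio content or timing differences.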
+ +Usage: + # Start backend first with Modulate: + # cd backend && STT_SERVICE_MODELS=modulate-velma-2 python3 -m uvicorn main:app --port 8700 + # Then run: + cd backend && python3 scripts/stt/r_ab_modulate_compare.py + cd backend && python3 scripts/stt/r_ab_modulate_compare.py --duration 120 # shorter test +""" + +import argparse +import asyncio +import json +import os +import re +import subprocess +import sys +import time +import urllib.parse +from pathlib import Path + +import websockets + +MODULATE_API_KEY = os.getenv('MODULATE_API_KEY', '') +BACKEND_HOST = 'localhost' +BACKEND_PORT = 8700 +LIBRISPEECH_DIR = Path('/tmp/librispeech/LibriSpeech/test-clean') + +CHUNK_SIZE = 3200 +CHUNK_INTERVAL_S = 0.1 +SILENCE_BETWEEN_UTTERANCES_S = 1.0 +SAMPLE_RATE = 16000 +PUNCT_RE = re.compile(r'[^\w\s]', re.UNICODE) + + +def normalize(text): + text = PUNCT_RE.sub(' ', text).upper() + return ' '.join(text.split()) + + +def compute_wer(ref, hyp): + ref_words = ref.split() + hyp_words = hyp.split() + if not ref_words: + return 0.0 if not hyp_words else 1.0 + d = [[0] * (len(hyp_words) + 1) for _ in range(len(ref_words) + 1)] + for i in range(len(ref_words) + 1): + d[i][0] = i + for j in range(len(hyp_words) + 1): + d[0][j] = j + for i in range(1, len(ref_words) + 1): + for j in range(1, len(hyp_words) + 1): + if ref_words[i - 1] == hyp_words[j - 1]: + d[i][j] = d[i - 1][j - 1] + else: + d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1]) + return d[len(ref_words)][len(hyp_words)] / len(ref_words) + + +def build_playlist(target_s): + playlist = [] + total_s = 0 + for reader_dir in sorted(LIBRISPEECH_DIR.iterdir()): + if not reader_dir.is_dir(): + continue + for chapter_dir in sorted(reader_dir.iterdir()): + if not chapter_dir.is_dir(): + continue + trans_file = list(chapter_dir.glob('*.trans.txt')) + if not trans_file: + continue + transcripts = {} + for line in trans_file[0].read_text().strip().split('\n'): + parts = line.strip().split(' ', 1) + if len(parts) == 2: + transcripts[parts[0]] = parts[1] + for flac in sorted(chapter_dir.glob('*.flac')): + uid = flac.stem + ref = transcripts.get(uid, '') + result = subprocess.run( + [ + 'ffprobe', + '-v', + 'error', + '-show_entries', + 'format=duration', + '-of', + 'default=noprint_wrappers=1:nokey=1', + str(flac), + ], + capture_output=True, + text=True, + ) + dur = float(result.stdout.strip()) if result.returncode == 0 else 5.0 + playlist.append({'flac': str(flac), 'ref': ref, 'uid': uid, 'duration_s': dur}) + total_s += dur + if total_s >= target_s: + return playlist, total_s + return playlist, total_s + + +def convert_to_pcm16(flac_path): + result = subprocess.run( + ['ffmpeg', '-y', '-i', flac_path, '-f', 's16le', '-ar', str(SAMPLE_RATE), '-ac', '1', 'pipe:1'], + capture_output=True, + ) + return result.stdout if result.returncode == 0 else None + + +async def send_audio(ws, playlist, label, send_eos=True): + """Send audio chunks with identical pacing. 
Returns (total_bytes, samples_sent).""" + total_bytes = 0 + samples_sent = 0 + silence = b'\x00' * (SAMPLE_RATE * 2 * int(SILENCE_BETWEEN_UTTERANCES_S)) + t0 = time.monotonic() + + for i, sample in enumerate(playlist): + pcm = convert_to_pcm16(sample['flac']) + if not pcm: + continue + + offset = 0 + while offset < len(pcm): + chunk = pcm[offset : offset + CHUNK_SIZE] + try: + await ws.send(chunk) + except Exception as e: + print(f' [{label}] Send error at sample {i}: {e}') + return total_bytes, samples_sent + offset += CHUNK_SIZE + total_bytes += len(chunk) + await asyncio.sleep(CHUNK_INTERVAL_S) + + samples_sent += 1 + + # 1s silence between utterances (identical for both tests) + try: + await ws.send(silence) + total_bytes += len(silence) + except Exception: + break + await asyncio.sleep(SILENCE_BETWEEN_UTTERANCES_S) + + if (i + 1) % 10 == 0: + elapsed = time.monotonic() - t0 + print(f' [{label}] [{elapsed:.0f}s] Sent {samples_sent}/{len(playlist)}, {total_bytes / 1024:.0f}KB') + + # Signal end of stream + if send_eos: + try: + await ws.send(b'') + except Exception: + pass + + elapsed = time.monotonic() - t0 + print(f' [{label}] All audio sent: {samples_sent} samples, {total_bytes / 1024:.0f}KB in {elapsed:.0f}s') + return total_bytes, samples_sent + + +async def test_direct_modulate(playlist): + """Test 1: Direct Modulate API — collect utterances and partials.""" + print('\n' + '=' * 70) + print('TEST A: Direct Modulate API') + print('=' * 70) + + params = { + 'api_key': MODULATE_API_KEY, + 'speaker_diarization': 'true', + 'partial_results': 'true', + 'sample_rate': str(SAMPLE_RATE), + 'audio_format': 's16le', + 'num_channels': '1', + 'language': 'en', + } + uri = f'wss://modulate-developer-apis.com/api/velma-2-stt-streaming?{urllib.parse.urlencode(params)}' + + ws = await websockets.connect(uri, ping_timeout=30, ping_interval=10, max_size=None) + print(' Connected to Modulate API') + + utterances = [] + partials = [] + last_partial_text = '' + done_event = asyncio.Event() + t0 = time.monotonic() + + async def recv(): + nonlocal last_partial_text + try: + async for raw in ws: + msg = json.loads(raw) + mt = msg.get('type', '') + elapsed = time.monotonic() - t0 + if mt == 'utterance': + utt = msg.get('utterance', msg) + text = utt.get('text', '').strip() + utterances.append(utt) + last_partial_text = '' + print(f' [DIRECT] [{elapsed:.1f}s] UTT #{len(utterances)}: {text[:80]}...') + elif mt == 'partial_utterance': + pu = msg.get('partial_utterance', msg) + partials.append(pu) + last_partial_text = pu.get('text', '').strip() + elif mt == 'done': + print(f' [DIRECT] [{elapsed:.1f}s] DONE: duration_ms={msg.get("duration_ms")}') + done_event.set() + break + elif mt == 'error': + print(f' [DIRECT] [{elapsed:.1f}s] ERROR: {msg}') + done_event.set() + break + except websockets.exceptions.ConnectionClosed as e: + print(f' [DIRECT] WS closed: {e}') + finally: + done_event.set() + + recv_task = asyncio.create_task(recv()) + await send_audio(ws, playlist, 'DIRECT', send_eos=False) + + print(' [DIRECT] Waiting for done event (up to 90s)...') + try: + await asyncio.wait_for(done_event.wait(), timeout=90) + except asyncio.TimeoutError: + print(' [DIRECT] Timed out waiting for done') + + recv_task.cancel() + try: + await ws.close() + except Exception: + pass + + # Build transcripts + utt_text = ' '.join(u.get('text', '') for u in utterances).strip() + # For partials: take the last partial text if no utterance followed it + # (simulates our backend's _flush_partial at done) + partial_final = utt_text + 
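# Worked example of the flush: if utterances end "...THE HOUSE" and the
+ # last partial is "AND THE GARDEN", the scored text becomes
+ # "...THE HOUSE AND THE GARDEN"; the endswith() guard below avoids
+ # double-counting a partial that the final utterance already absorbed.
+ 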
if last_partial_text and not utt_text.endswith(last_partial_text): + partial_final = (utt_text + ' ' + last_partial_text).strip() if utt_text else last_partial_text + + return { + 'utterances': utterances, + 'partials': partials, + 'utterance_text': utt_text, + 'utterance_plus_partial_text': partial_final, + 'utterance_count': len(utterances), + 'partial_count': len(partials), + } + + +async def test_backend_listen(playlist, port=BACKEND_PORT): + """Test 2: Backend /v4/listen with Modulate STT — collect segments.""" + print('\n' + '=' * 70) + print('TEST B: Backend /v4/listen (Modulate STT)') + print('=' * 70) + + params = { + 'language': 'en', + 'sample_rate': str(SAMPLE_RATE), + 'codec': 'pcm16', + 'channels': '1', + 'include_speech_profile': 'false', + 'conversation_timeout': '600', + } + url = f'ws://{BACKEND_HOST}:{port}/v4/listen?{"&".join(f"{k}={v}" for k, v in params.items())}' + + try: + ws = await asyncio.wait_for( + websockets.connect( + url, + additional_headers={'authorization': 'Bearer dev-token'}, + ping_timeout=None, + ping_interval=None, + max_size=None, + close_timeout=10, + ), + timeout=15, + ) + except Exception as e: + print(f' [BACKEND] Connection failed: {e}') + return None + + print(f' [BACKEND] Connected') + + segments = [] + segments_by_id = {} + ready = asyncio.Event() + recv_done = asyncio.Event() + t0 = time.monotonic() + + async def recv(): + try: + async for raw in ws: + if isinstance(raw, bytes) or raw == 'ping': + continue + elapsed = time.monotonic() - t0 + try: + msg = json.loads(raw) + except json.JSONDecodeError: + continue + + if isinstance(msg, list): + segs = msg + elif isinstance(msg, dict): + if msg.get('status') == 'ready': + ready.set() + print(f' [BACKEND] [{elapsed:.1f}s] Ready') + continue + if 'segments' in msg: + segs = msg['segments'] + else: + continue + else: + continue + + if not isinstance(segs, list): + continue + for seg in segs: + if not isinstance(seg, dict): + continue + text = seg.get('text', '').strip() + if text: + entry = { + 'id': seg.get('id', ''), + 'text': text, + 'speaker': seg.get('speaker', ''), + 'start': seg.get('start', 0), + 'end': seg.get('end', 0), + 'recv_ts': round(elapsed, 3), + } + segments.append(entry) + sid = seg.get('id', '') + if sid: + segments_by_id[sid] = entry + print(f' [BACKEND] [{elapsed:.1f}s] SEG: {text[:80]}') + except websockets.exceptions.ConnectionClosed: + pass + except Exception as e: + print(f' [BACKEND] Recv error: {e}') + finally: + recv_done.set() + + recv_task = asyncio.create_task(recv()) + + # Wait for ready (local dev takes ~60s due to Pusher retries) + try: + await asyncio.wait_for(ready.wait(), timeout=90) + except asyncio.TimeoutError: + print(' [BACKEND] No ready signal after 90s, proceeding...') + + await send_audio(ws, playlist, 'BACKEND', send_eos=False) + + # Wait for trailing results — match the drain timeout our implementation uses + print(' [BACKEND] Waiting 90s for trailing results...') + await asyncio.sleep(90) + + try: + await ws.close() + except Exception: + pass + await asyncio.sleep(2) + recv_task.cancel() + + final = list(segments_by_id.values()) if segments_by_id else segments + full_text = ' '.join(s['text'] for s in final).strip() + + return { + 'segments': final, + 'segment_updates': len(segments), + 'segment_final_count': len(final), + 'full_text': full_text, + } + + +def analyze_and_compare(ref_text, direct_result, backend_result): + """Compare WER and identify word-level differences.""" + ref_norm = normalize(ref_text) + ref_words = ref_norm.split() + + 
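# WER below is word-level edit distance over reference length after
+ # normalization, e.g. ref "THE CAT SAT" vs hyp "THE CAT SAT DOWN" is
+ # one insertion over three reference words, i.e. 33.3% WER.
+ 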
print('\n' + '=' * 70)
+    print('COMPARISON RESULTS')
+    print('=' * 70)
+
+    # Direct Modulate
+    d_utt_text = normalize(direct_result['utterance_text'])
+    d_utt_plus = normalize(direct_result['utterance_plus_partial_text'])
+    d_utt_words = d_utt_text.split()
+    d_utt_plus_words = d_utt_plus.split()
+    d_utt_wer = compute_wer(ref_norm, d_utt_text) if d_utt_text else 1.0
+    d_plus_wer = compute_wer(ref_norm, d_utt_plus) if d_utt_plus else 1.0
+
+    print(f'\n--- TEST A: Direct Modulate API ---')
+    print(f'  Utterances received: {direct_result["utterance_count"]}')
+    print(f'  Partials received: {direct_result["partial_count"]}')
+    print(f'  Words (utterances only): {len(d_utt_words)} / {len(ref_words)}')
+    print(f'  Words (utt + partial): {len(d_utt_plus_words)} / {len(ref_words)}')
+    print(f'  WER (utterances only): {d_utt_wer * 100:.1f}%')
+    print(f'  WER (utt + last partial): {d_plus_wer * 100:.1f}%')
+
+    # Backend
+    if backend_result is None:
+        print(f'\n--- TEST B: Backend --- FAILED (no connection)')
+        return
+
+    b_text = normalize(backend_result['full_text'])
+    b_words = b_text.split()
+    b_wer = compute_wer(ref_norm, b_text) if b_text else 1.0
+
+    print(f'\n--- TEST B: Backend /v4/listen (Modulate) ---')
+    print(f'  Segment updates: {backend_result["segment_updates"]}')
+    print(f'  Final segments: {backend_result["segment_final_count"]}')
+    print(f'  Words received: {len(b_words)} / {len(ref_words)}')
+    print(f'  WER: {b_wer * 100:.1f}%')
+
+    # Delta
+    print(f'\n--- DELTA (B minus A) ---')
+    wer_delta = (b_wer - d_plus_wer) * 100
+    word_delta = len(b_words) - len(d_utt_plus_words)
+    print(
+        f'  WER difference: {wer_delta:+.1f}% ({"WORSE" if wer_delta > 0 else "BETTER" if wer_delta < 0 else "SAME"})'
+    )
+    print(f'  Word count difference: {word_delta:+d} words')
+
+    if abs(wer_delta) < 2.0:
+        print(f'\n  VERDICT: WER difference is minimal (<2%). No significant implementation flaw.')
+    elif wer_delta > 0:
+        print(f'\n  VERDICT: Backend is {wer_delta:.1f}% worse than direct. Investigating...')
+        # Show word-level diff
+        _show_transcript_diff(d_utt_plus, b_text, ref_norm)
+    else:
+        print(f'\n  VERDICT: Backend is {-wer_delta:.1f}% better than direct (combine_segments dedup may help).')
+
+    # Show first few segments from each
+    print(f'\n--- TRANSCRIPT SAMPLES ---')
+    print(f'  REF (first 200 chars): {ref_text[:200]}')
+    print(f'  DIRECT (first 200 chars): {direct_result["utterance_plus_partial_text"][:200]}')
+    print(f'  BACKEND (first 200 chars): {backend_result["full_text"][:200]}')
+
+    # Save full results
+    out = {
+        'ref_words': len(ref_words),
+        'direct': {
+            'utterance_count': direct_result['utterance_count'],
+            'partial_count': direct_result['partial_count'],
+            'utt_word_count': len(d_utt_words),
+            'utt_plus_word_count': len(d_utt_plus_words),
+            'wer_utt': round(d_utt_wer, 4),
+            'wer_utt_plus': round(d_plus_wer, 4),
+            'text': direct_result['utterance_plus_partial_text'],
+        },
+        'backend': {
+            'segment_updates': backend_result['segment_updates'],
+            'final_segments': backend_result['segment_final_count'],
+            'word_count': len(b_words),
+            'wer': round(b_wer, 4),
+            'text': backend_result['full_text'],
+        },
+        'delta_wer_pct': round(wer_delta, 2),
+        'delta_words': word_delta,
+        'ref_text': ref_text,
+    }
+    out_path = '/tmp/modulate_ab_compare.json'
+    with open(out_path, 'w') as f:
+        json.dump(out, f, indent=2)
+    print(f'\nFull results saved to {out_path}')
+
+
+def _show_transcript_diff(direct_text, backend_text, ref_text):
+    """Show where backend transcript diverges from direct."""
+    d_words = direct_text.split()
+    b_words = backend_text.split()
+    r_words = ref_text.split()
+
+    # Find words in direct but not in backend (potential word loss)
+    d_set = set(w.lower() for w in d_words)
+    b_set = set(w.lower() for w in b_words)
+    lost = d_set - b_set
+    gained = b_set - d_set
+    if lost:
+        print(f'\n  Words in DIRECT but not in BACKEND (sample): {list(lost)[:20]}')
+    if gained:
+        print(f'  Words in BACKEND but not in DIRECT (sample): {list(gained)[:20]}')
+
+
+async def main():
+    parser = argparse.ArgumentParser(description='A/B: Direct Modulate vs Backend Listen')
+    parser.add_argument('--duration', type=int, default=300, help='Target audio duration (seconds)')
+    parser.add_argument('--port', type=int, default=BACKEND_PORT, help='Backend port')
+    parser.add_argument('--direct-only', action='store_true', help='Run only direct test')
+    parser.add_argument('--backend-only', action='store_true', help='Run only backend test')
+    args = parser.parse_args()
+
+    _port = args.port
+
+    if not LIBRISPEECH_DIR.exists():
+        print('ERROR: LibriSpeech not found. Run:')
+        print('  curl -L -o /tmp/test-clean.tar.gz https://www.openslr.org/resources/12/test-clean.tar.gz')
+        print('  cd /tmp && mkdir -p librispeech && tar xzf test-clean.tar.gz -C librispeech')
+        sys.exit(1)
+
+    print(f'Building playlist (target: {args.duration}s)...')
+    playlist, total_s = build_playlist(args.duration)
+    ref_text = ' '.join(s['ref'] for s in playlist)
+    ref_words = ref_text.split()
+    print(f'  {len(playlist)} utterances, {total_s:.1f}s, {len(ref_words)} ref words')
+
+    direct_result = None
+    backend_result = None
+
+    if not args.backend_only:
+        direct_result = await test_direct_modulate(playlist)
+
+    if not args.direct_only:
+        backend_result = await test_backend_listen(playlist, port=_port)
+
+    if direct_result and backend_result:
+        analyze_and_compare(ref_text, direct_result, backend_result)
+    elif direct_result:
+        ref_norm = normalize(ref_text)
+        d_text = normalize(direct_result['utterance_plus_partial_text'])
+        wer = compute_wer(ref_norm, d_text)
+        print(f'\nDirect only — WER: {wer * 100:.1f}%, Words: {len(d_text.split())} / {len(ref_words)}')
+    elif backend_result:
+        ref_norm = normalize(ref_text)
+        b_text = normalize(backend_result['full_text'])
+        wer = compute_wer(ref_norm, b_text)
+        print(f'\nBackend only — WER: {wer * 100:.1f}%, Words: {len(b_text.split())} / {len(ref_words)}')
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
diff --git a/backend/scripts/stt/r_modulate_stability.py b/backend/scripts/stt/r_modulate_stability.py
new file mode 100644
index 0000000000..b763044895
--- /dev/null
+++ b/backend/scripts/stt/r_modulate_stability.py
@@ -0,0 +1,309 @@
+"""
+Modulate WER stability test: same audio, same config, multiple runs.
+
+Sends identical audio with identical silence to Modulate N times.
+If WER varies significantly between runs, the issue is Modulate's
+non-determinism, not our test methodology.
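+
+A run-to-run spread under 5 WER points is reported as STABLE in the
+final verdict; anything larger points to provider-side non-determinism.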
+ +Usage: + cd backend && python3 scripts/stt/r_modulate_stability.py +""" + +import asyncio +import json +import os +import re +import subprocess +import sys +import time +import urllib.parse +from pathlib import Path + +import websockets + +MODULATE_API_KEY = os.getenv('MODULATE_API_KEY', '') +LIBRISPEECH_DIR = Path('/tmp/librispeech/LibriSpeech/test-clean') + +CHUNK_SIZE = 3200 +CHUNK_INTERVAL_S = 0.1 +SAMPLE_RATE = 16000 +PUNCT_RE = re.compile(r'[^\w\s]', re.UNICODE) + +RUNS_PER_CONFIG = 5 +CONFIGS = [ + {'silence_s': 5, 'label': '5s silence'}, + {'silence_s': 10, 'label': '10s silence'}, +] + + +def normalize(text): + text = PUNCT_RE.sub(' ', text).upper() + return ' '.join(text.split()) + + +def compute_wer(ref, hyp): + ref_words = ref.split() + hyp_words = hyp.split() + if not ref_words: + return 0.0 if not hyp_words else 1.0 + d = [[0] * (len(hyp_words) + 1) for _ in range(len(ref_words) + 1)] + for i in range(len(ref_words) + 1): + d[i][0] = i + for j in range(len(hyp_words) + 1): + d[0][j] = j + for i in range(1, len(ref_words) + 1): + for j in range(1, len(hyp_words) + 1): + if ref_words[i - 1] == hyp_words[j - 1]: + d[i][j] = d[i - 1][j - 1] + else: + d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1]) + return d[len(ref_words)][len(hyp_words)] / len(ref_words) + + +def build_playlist(target_s): + playlist = [] + total_s = 0 + for reader_dir in sorted(LIBRISPEECH_DIR.iterdir()): + if not reader_dir.is_dir(): + continue + for chapter_dir in sorted(reader_dir.iterdir()): + if not chapter_dir.is_dir(): + continue + trans_file = list(chapter_dir.glob('*.trans.txt')) + if not trans_file: + continue + transcripts = {} + for line in trans_file[0].read_text().strip().split('\n'): + parts = line.strip().split(' ', 1) + if len(parts) == 2: + transcripts[parts[0]] = parts[1] + for flac in sorted(chapter_dir.glob('*.flac')): + uid = flac.stem + ref = transcripts.get(uid, '') + result = subprocess.run( + [ + 'ffprobe', + '-v', + 'error', + '-show_entries', + 'format=duration', + '-of', + 'default=noprint_wrappers=1:nokey=1', + str(flac), + ], + capture_output=True, + text=True, + ) + dur = float(result.stdout.strip()) if result.returncode == 0 else 5.0 + playlist.append({'flac': str(flac), 'ref': ref, 'uid': uid, 'duration_s': dur}) + total_s += dur + if total_s >= target_s: + return playlist, total_s + return playlist, total_s + + +def convert_to_pcm16(flac_path): + result = subprocess.run( + ['ffmpeg', '-y', '-i', flac_path, '-f', 's16le', '-ar', str(SAMPLE_RATE), '-ac', '1', 'pipe:1'], + capture_output=True, + ) + return result.stdout if result.returncode == 0 else None + + +_pcm_cache = {} + + +def get_pcm(flac_path): + if flac_path not in _pcm_cache: + _pcm_cache[flac_path] = convert_to_pcm16(flac_path) + return _pcm_cache[flac_path] + + +async def run_single(playlist, silence_s): + """Single run: send audio to Modulate, collect results.""" + params = { + 'api_key': MODULATE_API_KEY, + 'speaker_diarization': 'true', + 'partial_results': 'true', + 'sample_rate': str(SAMPLE_RATE), + 'audio_format': 's16le', + 'num_channels': '1', + 'language': 'en', + } + uri = f'wss://modulate-developer-apis.com/api/velma-2-stt-streaming?{urllib.parse.urlencode(params)}' + + ws = await websockets.connect(uri, ping_timeout=30, ping_interval=10, max_size=None) + + utterances = [] + last_partial_text = '' + done_event = asyncio.Event() + utt_order = [] + + async def recv(): + nonlocal last_partial_text + try: + async for raw in ws: + msg = json.loads(raw) + mt = msg.get('type', '') + if mt == 
'utterance': + utt = msg.get('utterance', msg) + utterances.append(utt) + utt_order.append(utt.get('text', '')[:40]) + last_partial_text = '' + elif mt == 'partial_utterance': + pu = msg.get('partial_utterance', msg) + last_partial_text = pu.get('text', '').strip() + elif mt in ('done', 'error'): + done_event.set() + break + except websockets.exceptions.ConnectionClosed: + pass + finally: + done_event.set() + + recv_task = asyncio.create_task(recv()) + + silence_pcm = b'\x00' * int(SAMPLE_RATE * 2 * silence_s) if silence_s > 0 else b'' + total_bytes = 0 + + for i, sample in enumerate(playlist): + pcm = get_pcm(sample['flac']) + if not pcm: + continue + offset = 0 + while offset < len(pcm): + chunk = pcm[offset : offset + CHUNK_SIZE] + try: + await ws.send(chunk) + except Exception: + break + offset += CHUNK_SIZE + total_bytes += len(chunk) + await asyncio.sleep(CHUNK_INTERVAL_S) + + if i < len(playlist) - 1 and silence_pcm: + try: + await ws.send(silence_pcm) + total_bytes += len(silence_pcm) + except Exception: + break + await asyncio.sleep(silence_s) + + try: + await asyncio.wait_for(done_event.wait(), timeout=90) + except asyncio.TimeoutError: + pass + + recv_task.cancel() + try: + await ws.close() + except Exception: + pass + + utt_text = ' '.join(u.get('text', '') for u in utterances).strip() + full_text = utt_text + if last_partial_text and not utt_text.endswith(last_partial_text): + full_text = (utt_text + ' ' + last_partial_text).strip() if utt_text else last_partial_text + + return { + 'utterance_count': len(utterances), + 'full_text': full_text, + 'utt_order': utt_order, + } + + +async def main(): + print('Building playlist (target: 30s of speech)...') + playlist, total_s = build_playlist(30) + if not playlist: + print('ERROR: No LibriSpeech data.') + sys.exit(1) + + ref_text = ' '.join(s['ref'] for s in playlist) + ref_norm = normalize(ref_text) + ref_words = ref_norm.split() + print(f' {len(playlist)} utterances, {total_s:.1f}s speech, {len(ref_words)} ref words') + print(f' Ref: {ref_text[:120]}...\n') + + for s in playlist: + get_pcm(s['flac']) + + all_results = {} + + for config in CONFIGS: + silence_s = config['silence_s'] + label = config['label'] + print(f'{"=" * 70}') + print(f'{label} — {RUNS_PER_CONFIG} runs') + print(f'{"=" * 70}') + + runs = [] + for r in range(RUNS_PER_CONFIG): + print(f' Run {r + 1}/{RUNS_PER_CONFIG}...', end=' ', flush=True) + result = await run_single(playlist, silence_s) + hyp_norm = normalize(result['full_text']) + wer = compute_wer(ref_norm, hyp_norm) + words = len(hyp_norm.split()) if hyp_norm else 0 + + run_data = { + 'run': r + 1, + 'wer': wer, + 'words': words, + 'utts': result['utterance_count'], + 'utt_order': result['utt_order'], + 'text_sample': result['full_text'][:100], + } + runs.append(run_data) + print(f'WER={wer * 100:.1f}% words={words}/{len(ref_words)} utts={result["utterance_count"]}') + + # Brief pause between runs + await asyncio.sleep(2) + + wers = [r['wer'] for r in runs] + words_list = [r['words'] for r in runs] + avg_wer = sum(wers) / len(wers) + min_wer = min(wers) + max_wer = max(wers) + spread = max_wer - min_wer + avg_words = sum(words_list) / len(words_list) + + print(f'\n --- {label} summary ---') + print( + f' WER: avg={avg_wer * 100:.1f}% min={min_wer * 100:.1f}% max={max_wer * 100:.1f}% spread={spread * 100:.1f}%' + ) + print(f' Words: avg={avg_words:.0f}/{len(ref_words)}') + + # Show utterance order per run + print(f' Utterance arrival order:') + for r in runs: + order_str = ' → '.join(r['utt_order'][:4]) + 
print(f' Run {r["run"]}: [{r["utts"]} utts] {order_str}') + + all_results[label] = { + 'runs': runs, + 'avg_wer': avg_wer, + 'min_wer': min_wer, + 'max_wer': max_wer, + 'spread': spread, + 'avg_words': avg_words, + } + print() + + # Final verdict + print(f'{"=" * 70}') + print('VERDICT: Modulate WER Stability') + print(f'{"=" * 70}') + for label, data in all_results.items(): + stable = data['spread'] < 0.05 + status = 'STABLE (spread < 5%)' if stable else f'UNSTABLE (spread = {data["spread"] * 100:.1f}%)' + print( + f' {label}: avg WER = {data["avg_wer"] * 100:.1f}%, range = [{data["min_wer"] * 100:.1f}% - {data["max_wer"] * 100:.1f}%] → {status}' + ) + + with open('/tmp/modulate_stability.json', 'w') as f: + json.dump(all_results, f, indent=2, default=str) + print(f'\nRaw results saved to /tmp/modulate_stability.json') + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/backend/scripts/stt/r_silence_compression_test.py b/backend/scripts/stt/r_silence_compression_test.py new file mode 100644 index 0000000000..eab8a05937 --- /dev/null +++ b/backend/scripts/stt/r_silence_compression_test.py @@ -0,0 +1,319 @@ +""" +Test: Can we compress silence sent to Modulate to save costs? + +Sends identical speech audio with varying silence durations between utterances: + - 0s (no silence — back-to-back speech) + - 0.5s + - 1s + - 5s + - 10s (baseline — generous padding) + +If WER stays consistent across all silence durations, we can compress silence +and save bandwidth/costs while keeping Modulate's continuous stream intact. + +Usage: + cd backend && python3 scripts/stt/r_silence_compression_test.py +""" + +import asyncio +import json +import os +import re +import subprocess +import sys +import time +import urllib.parse +from pathlib import Path + +import websockets + +MODULATE_API_KEY = os.getenv('MODULATE_API_KEY', '') +LIBRISPEECH_DIR = Path('/tmp/librispeech/LibriSpeech/test-clean') + +CHUNK_SIZE = 3200 +CHUNK_INTERVAL_S = 0.1 +SAMPLE_RATE = 16000 +PUNCT_RE = re.compile(r'[^\w\s]', re.UNICODE) + +SILENCE_DURATIONS = [0, 0.5, 1, 5, 10] + + +def normalize(text): + text = PUNCT_RE.sub(' ', text).upper() + return ' '.join(text.split()) + + +def compute_wer(ref, hyp): + ref_words = ref.split() + hyp_words = hyp.split() + if not ref_words: + return 0.0 if not hyp_words else 1.0 + d = [[0] * (len(hyp_words) + 1) for _ in range(len(ref_words) + 1)] + for i in range(len(ref_words) + 1): + d[i][0] = i + for j in range(len(hyp_words) + 1): + d[0][j] = j + for i in range(1, len(ref_words) + 1): + for j in range(1, len(hyp_words) + 1): + if ref_words[i - 1] == hyp_words[j - 1]: + d[i][j] = d[i - 1][j - 1] + else: + d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1]) + return d[len(ref_words)][len(hyp_words)] / len(ref_words) + + +def build_playlist(target_s): + playlist = [] + total_s = 0 + for reader_dir in sorted(LIBRISPEECH_DIR.iterdir()): + if not reader_dir.is_dir(): + continue + for chapter_dir in sorted(reader_dir.iterdir()): + if not chapter_dir.is_dir(): + continue + trans_file = list(chapter_dir.glob('*.trans.txt')) + if not trans_file: + continue + transcripts = {} + for line in trans_file[0].read_text().strip().split('\n'): + parts = line.strip().split(' ', 1) + if len(parts) == 2: + transcripts[parts[0]] = parts[1] + for flac in sorted(chapter_dir.glob('*.flac')): + uid = flac.stem + ref = transcripts.get(uid, '') + result = subprocess.run( + [ + 'ffprobe', + '-v', + 'error', + '-show_entries', + 'format=duration', + '-of', + 'default=noprint_wrappers=1:nokey=1', + str(flac), 
+ ], + capture_output=True, + text=True, + ) + dur = float(result.stdout.strip()) if result.returncode == 0 else 5.0 + playlist.append({'flac': str(flac), 'ref': ref, 'uid': uid, 'duration_s': dur}) + total_s += dur + if total_s >= target_s: + return playlist, total_s + return playlist, total_s + + +def convert_to_pcm16(flac_path): + result = subprocess.run( + ['ffmpeg', '-y', '-i', flac_path, '-f', 's16le', '-ar', str(SAMPLE_RATE), '-ac', '1', 'pipe:1'], + capture_output=True, + ) + return result.stdout if result.returncode == 0 else None + + +# Pre-convert all audio once +_pcm_cache = {} + + +def get_pcm(flac_path): + if flac_path not in _pcm_cache: + _pcm_cache[flac_path] = convert_to_pcm16(flac_path) + return _pcm_cache[flac_path] + + +async def test_with_silence(playlist, silence_s, ref_text): + """Send audio to Modulate with specific silence duration between utterances.""" + params = { + 'api_key': MODULATE_API_KEY, + 'speaker_diarization': 'true', + 'partial_results': 'true', + 'sample_rate': str(SAMPLE_RATE), + 'audio_format': 's16le', + 'num_channels': '1', + 'language': 'en', + } + uri = f'wss://modulate-developer-apis.com/api/velma-2-stt-streaming?{urllib.parse.urlencode(params)}' + + ws = await websockets.connect(uri, ping_timeout=30, ping_interval=10, max_size=None) + + utterances = [] + partials = [] + last_partial_text = '' + done_event = asyncio.Event() + t0 = time.monotonic() + + async def recv(): + nonlocal last_partial_text + try: + async for raw in ws: + msg = json.loads(raw) + mt = msg.get('type', '') + if mt == 'utterance': + utt = msg.get('utterance', msg) + utterances.append(utt) + last_partial_text = '' + elif mt == 'partial_utterance': + pu = msg.get('partial_utterance', msg) + partials.append(pu) + last_partial_text = pu.get('text', '').strip() + elif mt == 'done': + done_event.set() + break + elif mt == 'error': + done_event.set() + break + except websockets.exceptions.ConnectionClosed: + pass + finally: + done_event.set() + + recv_task = asyncio.create_task(recv()) + + # Send audio with specified silence duration + total_bytes = 0 + silence_bytes = b'\x00' * int(SAMPLE_RATE * 2 * silence_s) if silence_s > 0 else b'' + total_silence_bytes = 0 + + for i, sample in enumerate(playlist): + pcm = get_pcm(sample['flac']) + if not pcm: + continue + + # Send speech audio in chunks + offset = 0 + while offset < len(pcm): + chunk = pcm[offset : offset + CHUNK_SIZE] + try: + await ws.send(chunk) + except Exception: + break + offset += CHUNK_SIZE + total_bytes += len(chunk) + await asyncio.sleep(CHUNK_INTERVAL_S) + + # Send silence between utterances (except after last) + if i < len(playlist) - 1 and silence_bytes: + try: + await ws.send(silence_bytes) + total_bytes += len(silence_bytes) + total_silence_bytes += len(silence_bytes) + except Exception: + break + # Pace silence to real-time + await asyncio.sleep(silence_s) + + elapsed_send = time.monotonic() - t0 + + # Wait for results + try: + await asyncio.wait_for(done_event.wait(), timeout=90) + except asyncio.TimeoutError: + pass + + recv_task.cancel() + try: + await ws.close() + except Exception: + pass + + elapsed_total = time.monotonic() - t0 + + # Build transcript + utt_text = ' '.join(u.get('text', '') for u in utterances).strip() + # Include last partial if not already in an utterance + full_text = utt_text + if last_partial_text and not utt_text.endswith(last_partial_text): + full_text = (utt_text + ' ' + last_partial_text).strip() if utt_text else last_partial_text + + ref_norm = normalize(ref_text) + hyp_norm = 
normalize(full_text) + wer = compute_wer(ref_norm, hyp_norm) + word_count = len(hyp_norm.split()) if hyp_norm else 0 + ref_count = len(ref_norm.split()) + + speech_bytes = total_bytes - total_silence_bytes + savings_pct = (total_silence_bytes / total_bytes * 100) if total_bytes > 0 else 0 + + return { + 'silence_s': silence_s, + 'wer': wer, + 'word_count': word_count, + 'ref_count': ref_count, + 'utterances': len(utterances), + 'partials': len(partials), + 'total_bytes': total_bytes, + 'speech_bytes': speech_bytes, + 'silence_bytes': total_silence_bytes, + 'savings_pct': savings_pct, + 'send_time': elapsed_send, + 'total_time': elapsed_total, + 'text_sample': full_text[:150], + } + + +async def main(): + print('Building playlist (target: 30s)...') + playlist, total_s = build_playlist(30) + if not playlist: + print('ERROR: No LibriSpeech data found. Download first.') + sys.exit(1) + + ref_text = ' '.join(s['ref'] for s in playlist) + ref_norm = normalize(ref_text) + ref_words = len(ref_norm.split()) + print(f' {len(playlist)} utterances, {total_s:.1f}s speech, {ref_words} ref words') + + # Pre-cache PCM + for s in playlist: + get_pcm(s['flac']) + + results = [] + + for silence_s in SILENCE_DURATIONS: + label = f'{silence_s}s' if silence_s > 0 else '0s (back-to-back)' + print(f'\n{"=" * 60}') + print(f'Testing silence = {label}') + print(f'{"=" * 60}') + + result = await test_with_silence(playlist, silence_s, ref_text) + results.append(result) + + print(f' WER: {result["wer"] * 100:.1f}%') + print(f' Words: {result["word_count"]} / {result["ref_count"]}') + print(f' Utterances: {result["utterances"]}') + print(f' Total bytes: {result["total_bytes"] / 1024:.0f} KB') + print(f' Silence bytes:{result["silence_bytes"] / 1024:.0f} KB ({result["savings_pct"]:.1f}% of total)') + print(f' Send time: {result["send_time"]:.1f}s') + print(f' Text sample: {result["text_sample"]}') + + # Summary table + print(f'\n{"=" * 60}') + print('SUMMARY: Silence Compression Results') + print(f'{"=" * 60}') + print(f'{"Silence":>10} {"WER":>8} {"Words":>8} {"UTTs":>6} {"Total KB":>10} {"Silence KB":>12} {"Savings":>10}') + print('-' * 70) + for r in results: + label = f'{r["silence_s"]}s' + print( + f'{label:>10} {r["wer"] * 100:>7.1f}% {r["word_count"]:>5}/{r["ref_count"]:<3}' + f' {r["utterances"]:>5} {r["total_bytes"] / 1024:>9.0f} {r["silence_bytes"] / 1024:>11.0f}' + f' {r["savings_pct"]:>9.1f}%' + ) + + # Verdict + baseline = results[-1] # 10s silence = baseline + print(f'\nBaseline (10s silence): WER = {baseline["wer"] * 100:.1f}%') + for r in results[:-1]: + delta = (r['wer'] - baseline['wer']) * 100 + direction = 'worse' if delta > 0 else 'better' if delta < 0 else 'same' + print(f' {r["silence_s"]}s silence: WER = {r["wer"] * 100:.1f}% ({delta:+.1f}% {direction})') + + # Save raw results + with open('/tmp/modulate_silence_test.json', 'w') as f: + json.dump(results, f, indent=2) + print(f'\nRaw results saved to /tmp/modulate_silence_test.json') + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/backend/scripts/stt/r_silence_sweep.py b/backend/scripts/stt/r_silence_sweep.py new file mode 100644 index 0000000000..c361ee2b0f --- /dev/null +++ b/backend/scripts/stt/r_silence_sweep.py @@ -0,0 +1,337 @@ +""" +Find the minimum silence duration that matches no-VAD WER with Modulate. + +1. Baseline: send audio with generous 15s silence (simulates no-VAD / passthrough) +2. Sweep: test 0s, 0.5s, 1s, 2s, 3s, 5s, 7s, 10s +3. 
Find which silence duration matches baseline WER + +Usage: + cd backend && python3 scripts/stt/r_silence_sweep.py +""" + +import asyncio +import json +import os +import re +import subprocess +import sys +import time +import urllib.parse +from pathlib import Path + +import websockets + +MODULATE_API_KEY = os.getenv('MODULATE_API_KEY', '') +LIBRISPEECH_DIR = Path('/tmp/librispeech/LibriSpeech/test-clean') + +CHUNK_SIZE = 3200 +CHUNK_INTERVAL_S = 0.1 +SAMPLE_RATE = 16000 +PUNCT_RE = re.compile(r'[^\w\s]', re.UNICODE) + +SWEEP_DURATIONS = [0, 0.5, 1, 2, 3, 5, 7, 10, 15] + + +def normalize(text): + text = PUNCT_RE.sub(' ', text).upper() + return ' '.join(text.split()) + + +def compute_wer(ref, hyp): + ref_words = ref.split() + hyp_words = hyp.split() + if not ref_words: + return 0.0 if not hyp_words else 1.0 + d = [[0] * (len(hyp_words) + 1) for _ in range(len(ref_words) + 1)] + for i in range(len(ref_words) + 1): + d[i][0] = i + for j in range(len(hyp_words) + 1): + d[0][j] = j + for i in range(1, len(ref_words) + 1): + for j in range(1, len(hyp_words) + 1): + if ref_words[i - 1] == hyp_words[j - 1]: + d[i][j] = d[i - 1][j - 1] + else: + d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1]) + return d[len(ref_words)][len(hyp_words)] / len(ref_words) + + +def build_playlist(target_s): + playlist = [] + total_s = 0 + for reader_dir in sorted(LIBRISPEECH_DIR.iterdir()): + if not reader_dir.is_dir(): + continue + for chapter_dir in sorted(reader_dir.iterdir()): + if not chapter_dir.is_dir(): + continue + trans_file = list(chapter_dir.glob('*.trans.txt')) + if not trans_file: + continue + transcripts = {} + for line in trans_file[0].read_text().strip().split('\n'): + parts = line.strip().split(' ', 1) + if len(parts) == 2: + transcripts[parts[0]] = parts[1] + for flac in sorted(chapter_dir.glob('*.flac')): + uid = flac.stem + ref = transcripts.get(uid, '') + result = subprocess.run( + [ + 'ffprobe', + '-v', + 'error', + '-show_entries', + 'format=duration', + '-of', + 'default=noprint_wrappers=1:nokey=1', + str(flac), + ], + capture_output=True, + text=True, + ) + dur = float(result.stdout.strip()) if result.returncode == 0 else 5.0 + playlist.append({'flac': str(flac), 'ref': ref, 'uid': uid, 'duration_s': dur}) + total_s += dur + if total_s >= target_s: + return playlist, total_s + return playlist, total_s + + +def convert_to_pcm16(flac_path): + result = subprocess.run( + ['ffmpeg', '-y', '-i', flac_path, '-f', 's16le', '-ar', str(SAMPLE_RATE), '-ac', '1', 'pipe:1'], + capture_output=True, + ) + return result.stdout if result.returncode == 0 else None + + +_pcm_cache = {} + + +def get_pcm(flac_path): + if flac_path not in _pcm_cache: + _pcm_cache[flac_path] = convert_to_pcm16(flac_path) + return _pcm_cache[flac_path] + + +async def run_test(playlist, silence_s): + """Send audio to Modulate with specified silence between utterances.""" + params = { + 'api_key': MODULATE_API_KEY, + 'speaker_diarization': 'true', + 'partial_results': 'true', + 'sample_rate': str(SAMPLE_RATE), + 'audio_format': 's16le', + 'num_channels': '1', + 'language': 'en', + } + uri = f'wss://modulate-developer-apis.com/api/velma-2-stt-streaming?{urllib.parse.urlencode(params)}' + + ws = await websockets.connect(uri, ping_timeout=30, ping_interval=10, max_size=None) + + utterances = [] + last_partial_text = '' + done_event = asyncio.Event() + + async def recv(): + nonlocal last_partial_text + try: + async for raw in ws: + msg = json.loads(raw) + mt = msg.get('type', '') + if mt == 'utterance': + utt = 
msg.get('utterance', msg) + utterances.append(utt) + last_partial_text = '' + elif mt == 'partial_utterance': + pu = msg.get('partial_utterance', msg) + last_partial_text = pu.get('text', '').strip() + elif mt in ('done', 'error'): + done_event.set() + break + except websockets.exceptions.ConnectionClosed: + pass + finally: + done_event.set() + + recv_task = asyncio.create_task(recv()) + + total_bytes = 0 + silence_bytes_total = 0 + silence_pcm = b'\x00' * int(SAMPLE_RATE * 2 * silence_s) if silence_s > 0 else b'' + + for i, sample in enumerate(playlist): + pcm = get_pcm(sample['flac']) + if not pcm: + continue + offset = 0 + while offset < len(pcm): + chunk = pcm[offset : offset + CHUNK_SIZE] + try: + await ws.send(chunk) + except Exception: + break + offset += CHUNK_SIZE + total_bytes += len(chunk) + await asyncio.sleep(CHUNK_INTERVAL_S) + + if i < len(playlist) - 1 and silence_pcm: + try: + await ws.send(silence_pcm) + total_bytes += len(silence_pcm) + silence_bytes_total += len(silence_pcm) + except Exception: + break + await asyncio.sleep(silence_s) + + # Wait for trailing results + try: + await asyncio.wait_for(done_event.wait(), timeout=90) + except asyncio.TimeoutError: + pass + + recv_task.cancel() + try: + await ws.close() + except Exception: + pass + + utt_text = ' '.join(u.get('text', '') for u in utterances).strip() + full_text = utt_text + if last_partial_text and not utt_text.endswith(last_partial_text): + full_text = (utt_text + ' ' + last_partial_text).strip() if utt_text else last_partial_text + + return { + 'silence_s': silence_s, + 'utterance_count': len(utterances), + 'full_text': full_text, + 'total_bytes': total_bytes, + 'silence_bytes': silence_bytes_total, + } + + +async def main(): + print('Building playlist (target: 30s of speech)...') + playlist, total_s = build_playlist(30) + if not playlist: + print('ERROR: No LibriSpeech data. 
Run the download first.') + sys.exit(1) + + ref_text = ' '.join(s['ref'] for s in playlist) + ref_norm = normalize(ref_text) + ref_words = ref_norm.split() + print(f' {len(playlist)} utterances, {total_s:.1f}s speech, {len(ref_words)} ref words\n') + + for s in playlist: + get_pcm(s['flac']) + + # Run baseline first (15s = no-VAD equivalent) + print('=' * 70) + print('BASELINE: 15s silence (no-VAD equivalent)') + print('=' * 70) + baseline = await run_test(playlist, 15) + b_norm = normalize(baseline['full_text']) + b_wer = compute_wer(ref_norm, b_norm) + b_words = len(b_norm.split()) if b_norm else 0 + print(f' WER: {b_wer * 100:.1f}%') + print(f' Words: {b_words}/{len(ref_words)}') + print(f' UTTs: {baseline["utterance_count"]}') + print(f' Bytes: {baseline["total_bytes"] / 1024:.0f} KB') + print(f' Text: {baseline["full_text"][:120]}...') + + # Sweep + rows = [] + rows.append( + { + 'silence_s': 15, + 'wer': b_wer, + 'words': b_words, + 'utts': baseline['utterance_count'], + 'total_kb': baseline['total_bytes'] / 1024, + 'silence_kb': baseline['silence_bytes'] / 1024, + 'delta_wer': 0, + 'text': baseline['full_text'][:100], + } + ) + + for s in [d for d in SWEEP_DURATIONS if d != 15]: + label = f'{s}s' if s > 0 else '0s' + print(f'\n--- Testing {label} silence ---') + result = await run_test(playlist, s) + hyp_norm = normalize(result['full_text']) + wer = compute_wer(ref_norm, hyp_norm) + words = len(hyp_norm.split()) if hyp_norm else 0 + delta = wer - b_wer + + rows.append( + { + 'silence_s': s, + 'wer': wer, + 'words': words, + 'utts': result['utterance_count'], + 'total_kb': result['total_bytes'] / 1024, + 'silence_kb': result['silence_bytes'] / 1024, + 'delta_wer': delta, + 'text': result['full_text'][:100], + } + ) + + status = 'MATCH' if abs(delta) < 0.03 else ('CLOSE' if abs(delta) < 0.08 else 'MISS') + print(f' WER: {wer * 100:.1f}% (delta: {delta * 100:+.1f}%) [{status}]') + print(f' Words: {words}/{len(ref_words)}, UTTs: {result["utterance_count"]}') + print(f' Total: {result["total_bytes"] / 1024:.0f} KB, Silence: {result["silence_bytes"] / 1024:.0f} KB') + + # Sort by silence duration for table + rows.sort(key=lambda r: r['silence_s']) + + print(f'\n{"=" * 70}') + print(f'RESULTS (baseline = 15s silence, WER = {b_wer * 100:.1f}%)') + print(f'{"=" * 70}') + print( + f'{"Silence":>8} {"WER":>7} {"Delta":>8} {"Words":>8} {"UTTs":>5} {"Total KB":>9} {"SilKB":>7} {"Savings":>8} {"Status":>7}' + ) + print('-' * 75) + + for r in rows: + savings = (1 - r['total_kb'] / rows[-1]['total_kb']) * 100 if rows[-1]['total_kb'] > 0 else 0 + # Compare to 15s baseline total bytes + baseline_kb = [x for x in rows if x['silence_s'] == 15][0]['total_kb'] + savings = (1 - r['total_kb'] / baseline_kb) * 100 if baseline_kb > 0 else 0 + + status = 'MATCH' if abs(r['delta_wer']) < 0.03 else ('CLOSE' if abs(r['delta_wer']) < 0.08 else 'MISS') + arrow = '<<<' if status == 'MATCH' and r['silence_s'] != 15 else '' + print( + f'{r["silence_s"]:>7}s {r["wer"] * 100:>6.1f}% {r["delta_wer"] * 100:>+7.1f}%' + f' {r["words"]:>5}/{len(ref_words):<3} {r["utts"]:>4}' + f' {r["total_kb"]:>8.0f} {r["silence_kb"]:>6.0f} {savings:>7.1f}%' + f' {status:>6} {arrow}' + ) + + # Find best match + non_baseline = [r for r in rows if r['silence_s'] != 15] + matches = [r for r in non_baseline if abs(r['delta_wer']) < 0.03] + close = [r for r in non_baseline if 0.03 <= abs(r['delta_wer']) < 0.08] + + print(f'\nBaseline WER (no-VAD): {b_wer * 100:.1f}%') + if matches: + best = min(matches, key=lambda r: r['silence_s']) + 
baseline_kb = [x for x in rows if x['silence_s'] == 15][0]['total_kb'] + savings = (1 - best['total_kb'] / baseline_kb) * 100 + print(f'ANSWER: {best["silence_s"]}s silence MATCHES baseline WER ({best["wer"] * 100:.1f}%)') + print(f' Bandwidth savings vs no-VAD: {savings:.1f}%') + elif close: + best = min(close, key=lambda r: abs(r['delta_wer'])) + print( + f'CLOSEST: {best["silence_s"]}s silence ({best["wer"] * 100:.1f}%, delta {best["delta_wer"] * 100:+.1f}%)' + ) + else: + print('No silence duration matched baseline WER within 3%.') + + with open('/tmp/modulate_silence_sweep.json', 'w') as f: + json.dump(rows, f, indent=2, default=str) + print(f'\nRaw results saved to /tmp/modulate_silence_sweep.json') + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/backend/test.sh b/backend/test.sh index 99a161e368..833c2e1baa 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -57,6 +57,7 @@ pytest tests/unit/test_people_conversations_500s.py -v pytest tests/unit/test_firestore_read_ops_cache.py -v pytest tests/unit/test_ws_auth_handshake.py -v pytest tests/unit/test_streaming_deepgram_backoff.py -v +pytest tests/unit/test_modulate_stt.py -v pytest tests/unit/test_batch_upload_storage.py -v pytest tests/unit/test_action_item_date_validation.py -v pytest tests/unit/test_action_item_dedup.py -v diff --git a/backend/tests/unit/test_modulate_stt.py b/backend/tests/unit/test_modulate_stt.py new file mode 100644 index 0000000000..1009586120 --- /dev/null +++ b/backend/tests/unit/test_modulate_stt.py @@ -0,0 +1,928 @@ +import asyncio +import json +import struct +import sys +import threading +import unittest +from io import BytesIO +from unittest.mock import AsyncMock, MagicMock, patch + +# Stub heavy deps before import +for mod in [ + 'google.cloud', + 'google.cloud.firestore', + 'google.cloud.firestore_v1', + 'google.cloud.storage', + 'google.auth', + 'google.auth.transport', + 'google.auth.transport.requests', + 'google.api_core', + 'google.api_core.exceptions', + 'firebase_admin', + 'firebase_admin.auth', + 'firebase_admin.firestore', + 'database.redis_db', + 'database.auth', + 'utils.other.storage', + 'deepgram', + 'deepgram.clients.live.v1', + 'fal_client', + 'opuslib', + 'silero_vad', +]: + if mod not in sys.modules: + sys.modules[mod] = MagicMock() + +# Stub deepgram classes needed at import time +sys.modules['deepgram'].DeepgramClient = MagicMock +sys.modules['deepgram'].DeepgramClientOptions = MagicMock +sys.modules['deepgram'].LiveTranscriptionEvents = MagicMock() +sys.modules['deepgram.clients.live.v1'].LiveOptions = MagicMock + +from utils.stt.streaming import ( + STTService, + SafeModulateSocket, + _build_wav_header, + get_stt_service_for_language, + modulate_languages, +) + + +class TestSTTServiceEnum(unittest.TestCase): + def test_modulate_enum_exists(self): + self.assertEqual(STTService.modulate.value, 'modulate') + + def test_get_model_name_modulate(self): + self.assertEqual(STTService.get_model_name(STTService.modulate), 'modulate_streaming') + + def test_get_model_name_deepgram(self): + self.assertEqual(STTService.get_model_name(STTService.deepgram), 'deepgram_streaming') + + +class TestLanguageRouting(unittest.TestCase): + @patch('utils.stt.streaming.stt_service_models', ['modulate-velma-2']) + def test_modulate_routing_english(self): + service, lang, model = get_stt_service_for_language('en') + self.assertEqual(service, STTService.modulate) + self.assertEqual(lang, 'en') + self.assertEqual(model, 'velma-2') + + @patch('utils.stt.streaming.stt_service_models', 
['modulate-velma-2'])
+    def test_modulate_routing_multi(self):
+        service, lang, model = get_stt_service_for_language('multi')
+        self.assertEqual(service, STTService.modulate)
+        self.assertEqual(lang, 'multi')
+
+    @patch('utils.stt.streaming.stt_service_models', ['modulate-velma-2'])
+    def test_modulate_unsupported_lang_fallback(self):
+        service, lang, model = get_stt_service_for_language('xx-unsupported')
+        self.assertEqual(service, STTService.deepgram)
+        self.assertEqual(lang, 'en')
+        self.assertEqual(model, 'nova-3')
+
+    @patch('utils.stt.streaming.stt_service_models', ['dg-nova-3'])
+    def test_deepgram_default(self):
+        service, lang, model = get_stt_service_for_language('en')
+        self.assertEqual(service, STTService.deepgram)
+
+    @patch('utils.stt.streaming.stt_service_models', ['dg-nova-3', 'modulate-velma-2'])
+    def test_deepgram_first_wins(self):
+        service, lang, model = get_stt_service_for_language('en')
+        self.assertEqual(service, STTService.deepgram)
+
+    @patch('utils.stt.streaming.stt_service_models', ['modulate-velma-2', 'dg-nova-3'])
+    def test_modulate_first_wins(self):
+        service, lang, model = get_stt_service_for_language('en')
+        self.assertEqual(service, STTService.modulate)
+
+    @patch('utils.stt.streaming.stt_service_models', ['dg-nova-3', 'modulate-velma-2'])
+    def test_dg_unsupported_falls_through_to_modulate(self):
+        service, lang, model = get_stt_service_for_language('af')
+        self.assertEqual(service, STTService.modulate)
+        self.assertEqual(lang, 'af')
+        self.assertEqual(model, 'velma-2')
+
+
+class TestWAVHeader(unittest.TestCase):
+    def test_wav_header_valid(self):
+        header = _build_wav_header(16000)
+        self.assertTrue(header.startswith(b'RIFF'))
+        self.assertIn(b'WAVE', header)
+        self.assertIn(b'fmt ', header)
+
+    def test_wav_header_sample_rate(self):
+        header = _build_wav_header(48000)
+        fmt_offset = header.index(b'fmt ') + 4
+        fmt_offset += 4  # skip chunk size
+        fmt_offset += 2  # skip audio format
+        fmt_offset += 2  # skip num channels
+        sr = struct.unpack_from('<I', header, fmt_offset)[0]
+        self.assertEqual(sr, 48000)

[...] -> Union[List[dict], Tuple[List[dict], str]]:
+    logger.info(f'modulate_prerecorded_from_bytes bytes_len={len(audio_bytes)} {sample_rate} {diarize} {attempts}')
+
+    api_key = os.getenv('MODULATE_API_KEY')
+    if not api_key:
+        raise ValueError('MODULATE_API_KEY environment variable is not set')
+
+    try:
+        url = 'https://modulate-developer-apis.com/api/velma-2-stt-batch'
+        headers = {'X-API-Key': api_key}
+        files = {'upload_file': ('audio.wav', BytesIO(audio_bytes), 'audio/wav')}
+        data = {'speaker_diarization': str(diarize).lower()}
+
+        with httpx.Client(timeout=300) as client:
+            response = client.post(url, headers=headers, files=files, data=data)
+            response.raise_for_status()
+            result = response.json()
+
+        utterances = result.get('utterances', [])
+        if not utterances:
+            if return_language:
+                return [], 'en'
+            return []
+
+        words = []
+        detected_language = 'en'
+        for utt in utterances:
+            text = utt.get('text', '').strip()
+            if not text:
+                continue
+
+            start_ms = utt.get('start_ms', 0)
+            duration_ms = utt.get('duration_ms', 0)
+            start = start_ms / 1000.0
+            end = (start_ms + duration_ms) / 1000.0
+
+            raw_speaker = utt.get('speaker')
+            if isinstance(raw_speaker, int) and raw_speaker >= 1:
+                speaker_idx = raw_speaker - 1
+            else:
+                speaker_idx = 0
+            speaker = f'SPEAKER_{speaker_idx:02d}'
+
+            words.append({'timestamp': [start, end], 'speaker': speaker, 'text': text})
+
+            lang = utt.get('language')
+            if lang:
+                detected_language = lang
+
+        if return_language:
+            return words, detected_language
+
+        return words
+
+    except Exception as e:
+        logger.error(f'Modulate prerecorded error: {e}')
+        if attempts < 2:
+            return modulate_prerecorded_from_bytes(audio_bytes, sample_rate, diarize, attempts + 1, return_language)
+        raise RuntimeError(f'Modulate transcription failed after {attempts + 1} attempts: {e}')
+
+
 def _words_cleaning(words: List[dict]):
     words_cleaned: List[dict] = []
     for i, w in enumerate(words):
diff --git a/backend/utils/stt/safe_socket.py b/backend/utils/stt/safe_socket.py
index e157600701..d75f6368fa 100644
--- a/backend/utils/stt/safe_socket.py
+++ b/backend/utils/stt/safe_socket.py
@@ -4,7 +4,7 @@
 can import SafeDeepgramSocket without pulling in GCP/storage dependencies.

 Architecture: SafeDeepgramSocket is the SOLE keepalive owner for a DG connection.
-No other layer (GatedDeepgramSocket, transcribe.py) should call keep_alive() directly.
+No other layer (GatedSTTSocket, transcribe.py) should call keep_alive() directly.
 A background daemon thread sends keepalive when idle > keepalive_interval_sec.
 """

@@ -14,6 +14,8 @@
 from dataclasses import dataclass
 from typing import Callable, Optional

+from utils.stt.socket import STTSocket
+
 logger = logging.getLogger(__name__)

@@ -37,7 +39,7 @@ def __post_init__(self):
         raise ValueError(f'check_period_sec must be > 0, got {self.check_period_sec}')

-class SafeDeepgramSocket:
+class SafeDeepgramSocket(STTSocket):
     """Wraps a raw Deepgram LiveConnection with auto-keepalive and dead-connection detection.

     Auto-keepalive: A background daemon thread sends keepalive when the connection
@@ -48,12 +50,10 @@
     Dead detection: Monitors send() and keep_alive() return values. When either
     returns False or raises, marks connection as permanently dead (one-way latch).

-    This is the SOLE keepalive owner — GatedDeepgramSocket and orchestrator code
+    This is the SOLE keepalive owner — GatedSTTSocket and orchestrator code
     must NOT call keep_alive() directly.
     """

-    _is_safe_dg_socket = True  # Marker for duck-type checks (avoids circular import)
-
     def __init__(
         self,
         dg_connection,
diff --git a/backend/utils/stt/socket.py b/backend/utils/stt/socket.py
new file mode 100644
index 0000000000..5507679042
--- /dev/null
+++ b/backend/utils/stt/socket.py
@@ -0,0 +1,22 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+
+class STTSocket(ABC):
+
+    @abstractmethod
+    def send(self, data: bytes) -> None: ...
+
+    @abstractmethod
+    def finish(self) -> None: ...
+
+    @abstractmethod
+    def finalize(self) -> None: ...
+
+    @property
+    @abstractmethod
+    def is_connection_dead(self) -> bool: ...
+
+    @property
+    @abstractmethod
+    def death_reason(self) -> Optional[str]: ...
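Reviewer note: socket.py gives every provider wrapper a single interface, replacing the old `_is_safe_dg_socket` duck-type marker. A minimal conforming stub (illustrative only; `LoopbackSocket` is not part of this diff) shows the full surface a new provider wrapper has to implement:

    from typing import Optional

    from utils.stt.socket import STTSocket


    class LoopbackSocket(STTSocket):
        """Hypothetical stand-in: records audio instead of calling a provider."""

        def __init__(self):
            self._dead = False
            self.sent = bytearray()

        def send(self, data: bytes) -> None:
            # A real wrapper would forward to the provider websocket here.
            self.sent.extend(data)

        def finalize(self) -> None:
            # Providers without an explicit flush can no-op, as SafeModulateSocket does.
            pass

        def finish(self) -> None:
            self._dead = True

        @property
        def is_connection_dead(self) -> bool:
            return self._dead

        @property
        def death_reason(self) -> Optional[str]:
            return 'finished' if self._dead else None

Anything satisfying this interface can be wrapped by GatedSTTSocket (see vad_gate.py below), which is the point of the abstraction.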
diff --git a/backend/utils/stt/streaming.py b/backend/utils/stt/streaming.py index b45043c318..bdb1a43a51 100644 --- a/backend/utils/stt/streaming.py +++ b/backend/utils/stt/streaming.py @@ -1,6 +1,11 @@ import asyncio +import io +import json import os import random +import threading +import urllib.parse +import wave as _wave from enum import Enum from typing import Callable, List, Optional @@ -10,7 +15,7 @@ from utils.byok import get_byok_key from utils.stt.safe_socket import KeepaliveConfig, SafeDeepgramSocket # noqa: F401 — re-exported for backward compat -from utils.stt.vad_gate import GatedDeepgramSocket +from utils.stt.socket import STTSocket import logging logger = logging.getLogger(__name__) @@ -21,11 +26,14 @@ class STTService(str, Enum): deepgram = "deepgram" + modulate = "modulate" @staticmethod def get_model_name(value): if value == STTService.deepgram: return 'deepgram_streaming' + if value == STTService.modulate: + return 'modulate_streaming' deepgram_nova3_multi_languages = { @@ -139,11 +147,93 @@ def get_model_name(value): } +modulate_languages = { + 'multi', + 'en', + 'af', + 'sq', + 'ar', + 'az', + 'eu', + 'be', + 'bn', + 'bs', + 'bg', + 'ca', + 'zh', + 'hr', + 'cs', + 'da', + 'nl', + 'et', + 'fi', + 'fr', + 'gl', + 'de', + 'el', + 'gu', + 'he', + 'hi', + 'hu', + 'id', + 'it', + 'ja', + 'kn', + 'kk', + 'ko', + 'lv', + 'lt', + 'mk', + 'ms', + 'ml', + 'mr', + 'no', + 'fa', + 'pl', + 'pt', + 'pa', + 'ro', + 'ru', + 'sr', + 'sk', + 'sl', + 'es', + 'sw', + 'sv', + 'tl', + 'ta', + 'te', + 'th', + 'tr', + 'uk', + 'ur', + 'vi', + 'cy', +} + +stt_service_models = os.getenv('STT_SERVICE_MODELS', 'dg-nova-3').split(',') + + +def _normalize_language(language: str) -> str: + if not language: + return '' + return language.split('-')[0].split('_')[0].lower() + + def get_stt_service_for_language(language: str, multi_lang_enabled: bool = True): - if multi_lang_enabled and language in deepgram_nova3_multi_languages: - return STTService.deepgram, 'multi', 'nova-3' - if language in deepgram_nova3_languages: - return STTService.deepgram, language, 'nova-3' + base_lang = _normalize_language(language) + for m in stt_service_models: + m = m.strip() + if m.startswith('dg-'): + dg_model = m.replace('dg-', '', 1) + if multi_lang_enabled and language in deepgram_nova3_multi_languages: + return STTService.deepgram, 'multi', dg_model + if language in deepgram_nova3_languages: + return STTService.deepgram, language, dg_model + continue + if m == 'modulate-velma-2': + if base_lang in modulate_languages: + return STTService.modulate, base_lang, 'velma-2' # Fallback to deepgram nova-3 with English return STTService.deepgram, 'en', 'nova-3' @@ -187,26 +277,10 @@ async def process_audio_dg( channels: int, model: str = 'nova-3', keywords: List[str] = [], - vad_gate=None, is_active: Optional[Callable[[], bool]] = None, ): - """Create a Deepgram streaming connection. - - Args: - vad_gate: Optional VADStreamingGate. If provided, returns a - GatedDeepgramSocket that handles VAD gating internally and - remaps timestamps in the stream_transcript callback. 
- """ logger.info(f'process_audio_dg {language} {sample_rate} {channels}') - # If gate provided, wrap stream_transcript to remap DG timestamps - if vad_gate is not None: - _original_stream_transcript = stream_transcript - - def stream_transcript(segments): - vad_gate.remap_segments(segments) - _original_stream_transcript(segments) - def on_message(self, result, **kwargs): sentence = result.channel.alternatives[0].transcript if len(sentence) == 0: @@ -271,9 +345,6 @@ def on_dg_error(self, error, **kwargs): dg_connection.on(LiveTranscriptionEvents.Close, on_dg_close) dg_connection.on(LiveTranscriptionEvents.Error, on_dg_error) - # Wrap with VAD gate if provided - if vad_gate is not None: - return GatedDeepgramSocket(safe_conn, gate=vad_gate) return safe_conn @@ -407,3 +478,258 @@ def on_unhandled(self, unhandled, **kwargs): raise Exception(f'Could not open socket: WebSocketException {e}') except Exception as e: raise Exception(f'Could not open socket: {e}') + + +# --------------------------------------------------------------------------- +# Modulate (Velma-2) streaming +# --------------------------------------------------------------------------- + + +def _build_wav_header(sample_rate: int, bits_per_sample: int = 16, channels: int = 1) -> bytes: + buf = io.BytesIO() + with _wave.open(buf, 'wb') as wf: + wf.setnchannels(channels) + wf.setsampwidth(bits_per_sample // 8) + wf.setframerate(sample_rate) + wf.writeframes(b'') + return buf.getvalue() + + +class SafeModulateSocket(STTSocket): + + def __init__(self, ws, stream_transcript, loop, preseconds: int = 0): + self._ws = ws + self._stream_transcript = stream_transcript + self._loop = loop + self._preseconds = preseconds + self._dead = False + self._closed = False + self._death_reason: Optional[str] = None + self._lock = threading.Lock() + self._header_sent = False + self._wav_header: Optional[bytes] = None + self._send_queue: asyncio.Queue = asyncio.Queue(maxsize=2000) + self._done_event = asyncio.Event() + self._prev_partial_text: str = '' + self._prev_partial_start_ms: int = 0 + self._prev_partial_word_count: int = 0 + self._recv_task = asyncio.ensure_future(self._recv_loop(), loop=loop) + self._send_task = asyncio.ensure_future(self._send_loop(), loop=loop) + + def set_wav_header(self, header: bytes): + self._wav_header = header + + @property + def is_connection_dead(self) -> bool: + return self._dead + + @property + def death_reason(self) -> Optional[str]: + return self._death_reason + + def _mark_dead(self, reason: str): + with self._lock: + if not self._dead: + self._dead = True + self._death_reason = reason + + def send(self, data: bytes) -> None: + with self._lock: + if self._dead or self._closed: + return + if not self._header_sent and self._wav_header: + data = self._wav_header + data + self._header_sent = True + + def _enqueue(): + try: + self._send_queue.put_nowait(data) + except asyncio.QueueFull: + self._mark_dead('send queue full') + + try: + self._loop.call_soon_threadsafe(_enqueue) + except RuntimeError: + self._mark_dead('event loop closed') + + def finalize(self) -> None: + pass + + def finish(self) -> None: + with self._lock: + if self._closed: + return + self._closed = True + try: + self._loop.call_soon_threadsafe(lambda: self._send_queue.put_nowait(b'')) + except (RuntimeError, Exception): + pass + + async def drain_and_close(self): + try: + await asyncio.sleep(0) + _EOS_SENTINEL = b'__EOS__' + try: + self._send_queue.put_nowait(_EOS_SENTINEL) + except asyncio.QueueFull: + pass + try: + await 
asyncio.wait_for(self._send_task, timeout=10) + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + try: + await asyncio.wait_for(self._done_event.wait(), timeout=60) + except (asyncio.TimeoutError, asyncio.CancelledError): + logger.warning('Modulate drain timed out waiting for done message') + except Exception: + pass + self._recv_task.cancel() + try: + await self._ws.close() + except Exception: + pass + + async def _send_loop(self): + _EOS_SENTINEL = b'__EOS__' + try: + while not self._closed and not self._dead: + data = await self._send_queue.get() + if data == b'' or data == _EOS_SENTINEL: + break + await self._ws.send(data) + except websockets.exceptions.ConnectionClosed as e: + self._mark_dead(f'ws send closed: {e}') + except Exception as e: + self._mark_dead(f'ws send error: {e}') + + async def _recv_loop(self): + try: + async for raw_msg in self._ws: + if self._closed: + break + try: + msg = json.loads(raw_msg) + except (json.JSONDecodeError, TypeError): + continue + + msg_type = msg.get('type', '') + if msg_type == 'error': + err = msg.get('error', msg.get('message', 'unknown error')) + logger.error(f'Modulate streaming error: {err}') + if self._prev_partial_text: + self._flush_partial() + self._done_event.set() + self._mark_dead(f'modulate error: {err}') + break + elif msg_type == 'done': + logger.info('Modulate streaming done: duration_ms=%s', msg.get('duration_ms')) + if self._prev_partial_text: + self._flush_partial() + self._done_event.set() + break + elif msg_type == 'partial_utterance': + pu = msg.get('partial_utterance', msg) + self._handle_partial_utterance(pu) + elif msg_type == 'utterance': + utt = msg.get('utterance', msg) + self._handle_utterance(utt) + except websockets.exceptions.ConnectionClosed as e: + self._mark_dead(f'ws recv closed: {e}') + except Exception as e: + self._mark_dead(f'ws recv error: {e}') + + def _handle_partial_utterance(self, msg: dict): + text = msg.get('text', '').strip() + if not text: + return + start_ms = msg.get('start_ms', 0) + self._prev_partial_text = text + self._prev_partial_start_ms = start_ms + self._prev_partial_word_count = len(text.split()) + + def _flush_partial(self): + text = self._prev_partial_text + start_ms = self._prev_partial_start_ms + self._prev_partial_text = '' + self._prev_partial_word_count = 0 + if not text: + return + start = start_ms / 1000.0 + if self._preseconds and start < self._preseconds: + return + segments = [ + { + 'speaker': 'SPEAKER_00', + 'start': start, + 'end': start, + 'text': text, + 'is_user': False, + 'person_id': None, + } + ] + self._stream_transcript(segments) + + def _handle_utterance(self, msg: dict): + text = msg.get('text', '').strip() + if not text: + return + + self._prev_partial_text = '' + self._prev_partial_word_count = 0 + + start_ms = msg.get('start_ms', 0) + duration_ms = msg.get('duration_ms', 0) + start = start_ms / 1000.0 + end = (start_ms + duration_ms) / 1000.0 + + if self._preseconds and start < self._preseconds: + return + + raw_speaker = msg.get('speaker') + if isinstance(raw_speaker, int) and raw_speaker >= 1: + speaker_idx = raw_speaker - 1 + else: + speaker_idx = 0 + speaker = f'SPEAKER_{speaker_idx:02d}' + + segments = [ + { + 'speaker': speaker, + 'start': start, + 'end': end, + 'text': text, + 'is_user': False, + 'person_id': None, + } + ] + self._stream_transcript(segments) + + +async def process_audio_modulate( + stream_transcript, + sample_rate: int, + language: str, + preseconds: int = 0, +): + api_key = os.getenv('MODULATE_API_KEY') + if not api_key: 
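+        # Fail fast: the streaming connection below is unusable without credentials.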
+ raise ValueError('MODULATE_API_KEY environment variable is not set') + + params = { + 'api_key': api_key, + 'speaker_diarization': 'true', + 'partial_results': 'true', + 'sample_rate': str(sample_rate), + 'audio_format': 's16le', + 'num_channels': '1', + } + if language and language != 'multi': + params['language'] = language + uri = f'wss://modulate-developer-apis.com/api/velma-2-stt-streaming?{urllib.parse.urlencode(params)}' + + logger.info(f'Connecting to Modulate Velma-2 streaming sample_rate={sample_rate} language={language}') + ws = await websockets.connect(uri, ping_timeout=10, ping_interval=10) + loop = asyncio.get_running_loop() + sock = SafeModulateSocket(ws, stream_transcript, loop, preseconds=preseconds) + logger.info('Modulate Velma-2 streaming connection established') + return sock diff --git a/backend/utils/stt/vad_gate.py b/backend/utils/stt/vad_gate.py index f37422d9c4..0a34faa87b 100644 --- a/backend/utils/stt/vad_gate.py +++ b/backend/utils/stt/vad_gate.py @@ -1,7 +1,7 @@ """ VAD Streaming Gate — Issue #4644 -Server-side VAD gate that skips sending silence to Deepgram, +Server-side VAD gate that skips sending silence to the STT provider, using KeepAlive to maintain the connection and Finalize to flush pending transcripts on speech→silence transitions. @@ -24,6 +24,7 @@ import numpy as np +from utils.stt.socket import STTSocket from utils.stt.vad import _get_ort_session, make_fresh_state, run_vad_window, VAD_WINDOW_SAMPLES logger = logging.getLogger('vad_gate') @@ -65,13 +66,13 @@ class GateOutput: # --------------------------------------------------------------------------- # DG ↔ Wall-clock timestamp mapper # --------------------------------------------------------------------------- -class DgWallMapper: - """Maps DG audio-time timestamps to wall-clock-relative timestamps. +class WallTimeMapper: + """Maps STT provider audio-time timestamps to wall-clock-relative timestamps. - DG timestamps are continuous (only counting audio actually sent). - When we skip silence via KeepAlive, DG time compresses vs wall time. + Provider timestamps are continuous (only counting audio actually sent). + When we skip silence via KeepAlive, provider time compresses vs wall time. This mapper tracks checkpoints at each silence→speech transition to - convert DG timestamps back to wall-clock-relative timestamps. + convert provider timestamps back to wall-clock-relative timestamps. """ _MAX_CHECKPOINTS = 500 # Cap to bound memory for long sessions @@ -80,7 +81,7 @@ def __init__(self): self._lock = threading.Lock() # Each checkpoint: (dg_sec, wall_rel_sec) at silence→speech transition self._checkpoints: List[Tuple[float, float]] = [] - self._dg_cursor_sec: float = 0.0 + self._provider_cursor_sec: float = 0.0 self._sending: bool = False def on_audio_sent(self, chunk_duration_sec: float, chunk_wall_rel_sec: float) -> None: @@ -94,9 +95,9 @@ def on_audio_sent(self, chunk_duration_sec: float, chunk_wall_rel_sec: float) -> # ranges that cause non-monotonic remapped timestamps. if self._checkpoints: prev_dg, prev_wall = self._checkpoints[-1] - min_wall = prev_wall + (self._dg_cursor_sec - prev_dg) + min_wall = prev_wall + (self._provider_cursor_sec - prev_dg) chunk_wall_rel_sec = max(chunk_wall_rel_sec, min_wall) - self._checkpoints.append((self._dg_cursor_sec, chunk_wall_rel_sec)) + self._checkpoints.append((self._provider_cursor_sec, chunk_wall_rel_sec)) # Compact: keep an anchor for early remaps + recent checkpoints. 
if len(self._checkpoints) > self._MAX_CHECKPOINTS: if self._MAX_CHECKPOINTS <= 1: @@ -104,7 +105,7 @@ def on_audio_sent(self, chunk_duration_sec: float, chunk_wall_rel_sec: float) -> else: self._checkpoints = [self._checkpoints[0]] + self._checkpoints[-(self._MAX_CHECKPOINTS - 1) :] self._sending = True - self._dg_cursor_sec += chunk_duration_sec + self._provider_cursor_sec += chunk_duration_sec def on_silence_skipped(self) -> None: """Called when silence is skipped (not sent to DG).""" @@ -127,7 +128,7 @@ def dg_to_wall_rel(self, dg_sec: float) -> float: # VAD Streaming Gate (per-session) # --------------------------------------------------------------------------- class VADStreamingGate: - """Per-session VAD gate that decides whether to send audio to DG. + """Per-session VAD gate that decides whether to send audio to the STT provider. Uses ONNX Silero-VAD model's speech probability (not start/end events) for robust per-chunk speech detection. Buffers VAD input samples to handle @@ -185,7 +186,7 @@ def __init__( self._pre_roll_total_ms: float = 0.0 # Timestamp mapper - self.dg_wall_mapper = DgWallMapper() + self.dg_wall_mapper = WallTimeMapper() # Metrics self._chunks_total = 0 @@ -206,8 +207,8 @@ def __init__( def activate(self) -> None: """Switch from shadow to active mode (used after speech profile completes). - Advances the DgWallMapper cursor to account for all audio sent during - shadow mode. Without this, the mapper would think DG cursor is at 0 + Advances the WallTimeMapper cursor to account for all audio sent during + shadow mode. Without this, the mapper would think provider cursor is at 0 and over-shift all timestamps after the first gated silence gap. """ if self.mode == 'shadow': @@ -221,7 +222,7 @@ def activate(self) -> None: self._vad_state, self._vad_context = make_fresh_state() self._vad_buffer = np.array([], dtype=np.float32) # Sync mapper cursor: DG received all audio during shadow phase - self.dg_wall_mapper._dg_cursor_sec = self._audio_cursor_ms / 1000.0 + self.dg_wall_mapper._provider_cursor_sec = self._audio_cursor_ms / 1000.0 logger.info( 'VADGate activated shadow->active uid=%s session=%s cursor=%.1fms', self.uid, @@ -230,7 +231,7 @@ def activate(self) -> None: ) def needs_keepalive(self, wall_time: float) -> bool: - """Check if a keepalive should be sent to prevent DG timeout.""" + """Check if a keepalive should be sent to prevent STT provider timeout.""" if self.mode != 'active': return False ref_time = self._last_send_wall_time or self._first_audio_wall_time @@ -526,7 +527,7 @@ def to_json_log(self) -> dict: } def remap_segments(self, segments: list) -> None: - """Remap DG timestamps to wall-clock-relative if gate is active.""" + """Remap STT provider timestamps to wall-clock-relative if gate is active.""" if self.mode == 'active': for seg in segments: seg['start'] = self.dg_wall_mapper.dg_to_wall_rel(seg['start']) @@ -539,13 +540,13 @@ def record_keepalive(self, wall_time: float) -> None: # --------------------------------------------------------------------------- -# Gated Deepgram Socket — wraps raw DG connection with VAD gate +# Gated STT Socket — wraps any STTSocket with VAD gate # --------------------------------------------------------------------------- -class GatedDeepgramSocket: - """Wraps a Deepgram LiveConnection with built-in VAD gate. +class GatedSTTSocket(STTSocket): + """Wraps an STTSocket with built-in VAD gate. 
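+    Works with any STTSocket implementation (SafeDeepgramSocket and
+    SafeModulateSocket both qualify), so the gating logic is provider-agnostic.
+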
When gate is active: - - send() runs VAD internally, only forwards speech audio to DG + - send() runs VAD internally, only forwards speech audio to the STT provider - Automatically calls finalize() on speech→silence transitions - finish() flushes pending transcript before closing When gate is None or mode='shadow': @@ -554,9 +555,12 @@ class GatedDeepgramSocket: This keeps all VAD logic out of transcribe.py. """ - def __init__(self, dg_connection, gate: Optional['VADStreamingGate'] = None): - self._conn = dg_connection + def __init__( + self, stt_connection: STTSocket, gate: Optional['VADStreamingGate'] = None, passthrough_audio: bool = False + ): + self._conn = stt_connection self._gate = gate + self._passthrough_audio = passthrough_audio # Audio capture for transcript quality validation (off by default) self._capture_dir = os.getenv('VAD_GATE_AUDIO_CAPTURE_DIR', '') self._raw_file = None @@ -569,18 +573,16 @@ def __init__(self, dg_connection, gate: Optional['VADStreamingGate'] = None): @property def is_connection_dead(self) -> bool: - """True if DG connection has been detected as dead. Delegates to SafeDeepgramSocket.""" - if getattr(self._conn, '_is_safe_dg_socket', None) is True: + if isinstance(self._conn, STTSocket): return self._conn.is_connection_dead return False @property def death_reason(self) -> Optional[str]: - """Why the DG connection died. Delegates to SafeDeepgramSocket.""" return self._conn.death_reason def send(self, data: bytes, wall_time: Optional[float] = None) -> None: - """Send audio through VAD gate (if active), then to DG.""" + """Send audio through VAD gate (if active), then to the STT provider.""" if self.is_connection_dead: return if self._gate is None: @@ -598,11 +600,10 @@ def send(self, data: bytes, wall_time: Optional[float] = None) -> None: self._raw_file.write(data) if self._gated_file and gate_out.audio_to_send: self._gated_file.write(gate_out.audio_to_send) - if gate_out.audio_to_send: - # SafeDeepgramSocket.send() handles dead detection internally + if self._passthrough_audio: + self._conn.send(data) + elif gate_out.audio_to_send: self._conn.send(gate_out.audio_to_send) - # Keepalive is handled automatically by SafeDeepgramSocket's background thread. - # No explicit keep_alive() call needed here (#5870 architecture). if gate_out.should_finalize: try: self._conn.finalize() @@ -615,7 +616,7 @@ def finalize(self) -> None: self._conn.finalize() def finish(self) -> None: - """Close DG connection. Flushes first if gate is active.""" + """Close STT connection. Flushes first if gate is active.""" if self._gate is not None and self._gate.mode == 'active': try: self._conn.finalize() @@ -631,7 +632,7 @@ def finish(self) -> None: pass def remap_segments(self, segments: list) -> None: - """Remap DG timestamps from audio-time to wall-clock-relative time.""" + """Remap STT provider timestamps from audio-time to wall-clock-relative time.""" if self._gate is not None: self._gate.remap_segments(segments) @@ -644,3 +645,8 @@ def get_metrics(self) -> Optional[dict]: @property def is_gated(self) -> bool: return self._gate is not None + + +# Backward-compatibility aliases +GatedDeepgramSocket = GatedSTTSocket +DgWallMapper = WallTimeMapper
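For reviewers, a sketch of how the new pieces compose end to end. This is illustrative and not part of the diff: `demo` and `on_segments` are made-up names, the gate is left as None (pure passthrough), and MODULATE_API_KEY is assumed to be set in the environment.

    import asyncio

    from utils.stt.streaming import process_audio_modulate
    from utils.stt.vad_gate import GatedSTTSocket


    async def demo():
        def on_segments(segments):
            # Each segment dict carries: speaker, start, end, text, is_user, person_id.
            print(segments)

        # process_audio_modulate opens the Velma-2 websocket and returns a
        # SafeModulateSocket, which is an STTSocket.
        sock = await process_audio_modulate(on_segments, sample_rate=16000, language='en')

        # gate=None means pure passthrough; a VADStreamingGate could be supplied
        # to skip silence before it reaches the provider.
        gated = GatedSTTSocket(sock, gate=None)

        gated.send(b'\x00' * 3200)  # one 100ms chunk of 16 kHz s16le mono silence
        await asyncio.sleep(1.0)    # let the recv loop deliver any segments
        gated.finish()


    if __name__ == '__main__':
        asyncio.run(demo())

With a real VADStreamingGate in active mode, remap_segments converts provider audio-time back to wall-clock-relative time via the WallTimeMapper checkpoints; for example, a segment starting at provider time 12.3s remaps to 27.3s if 15s of silence was skipped earlier in the session.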