diff --git a/README.md b/README.md index 09436b0..a6c65c0 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,26 @@ whispers completions zsh - `whispers` installs the helper rewrite worker for you when that feature is enabled. - Shell completions are printed to `stdout`. +## Troubleshooting + +If the main `whispers` process ever gets stuck after playback when using local +`whisper_cpp`, enable the built-in hang diagnostics for the next repro: + +```sh +WHISPERS_HANG_DEBUG=1 whispers +``` + +When that mode is enabled, `whispers` writes runtime status and hang bundles +under `${XDG_RUNTIME_DIR:-/tmp}/whispers/`: + +- `main-status.json` shows the current dictation stage and recent stage metadata. +- `hang---.log` is emitted if `whisper_cpp` spends too + long in model load or transcription. + +Those bundles include the current status snapshot plus best-effort stack and +open-file diagnostics. If the hang reproduces, capture the newest `hang-*.log` +file along with `main-status.json`. + ## License [Mozilla Public License 2.0](LICENSE) diff --git a/src/agentic_rewrite/mod.rs b/src/agentic_rewrite/mod.rs index 0b6d394..07cf0ae 100644 --- a/src/agentic_rewrite/mod.rs +++ b/src/agentic_rewrite/mod.rs @@ -46,6 +46,12 @@ struct PreparedGlossaryEntry { normalized_aliases: Vec>, } +#[derive(Debug, Clone, Default)] +pub(crate) struct RuntimePolicyResources { + policy_rules: Vec, + glossary_entries: Vec, +} + pub use runtime::conservative_output_allowed; pub fn default_policy_path() -> &'static str { @@ -56,17 +62,43 @@ pub fn default_glossary_path() -> &'static str { store::default_glossary_path() } +pub(crate) fn load_runtime_resources(config: &Config) -> RuntimePolicyResources { + load_runtime_resources_with_status(config).0 +} + +pub(crate) fn load_runtime_resources_with_status( + config: &Config, +) -> (RuntimePolicyResources, bool) { + let (policy_rules, policy_degraded) = + store::load_policy_file_for_runtime_with_status(&config.resolved_rewrite_policy_path()); + let 
(glossary_entries, glossary_degraded) = + store::load_glossary_file_for_runtime_with_status(&config.resolved_rewrite_glossary_path()); + + ( + RuntimePolicyResources { + policy_rules, + glossary_entries, + }, + policy_degraded || glossary_degraded, + ) +} + pub fn apply_runtime_policy(config: &Config, transcript: &mut RewriteTranscript) { - let policy_rules = store::load_policy_file_for_runtime(&config.resolved_rewrite_policy_path()); - let glossary_entries = - store::load_glossary_file_for_runtime(&config.resolved_rewrite_glossary_path()); + let resources = load_runtime_resources(config); + apply_runtime_policy_with_resources(config, transcript, &resources); +} +pub(crate) fn apply_runtime_policy_with_resources( + config: &Config, + transcript: &mut RewriteTranscript, + resources: &RuntimePolicyResources, +) { let policy_context = runtime::resolve_policy_context( config.rewrite.default_correction_policy, transcript.typing_context.as_ref(), &transcript.rewrite_candidates, - &policy_rules, - &glossary_entries, + &resources.policy_rules, + &resources.glossary_entries, ); for candidate in &policy_context.glossary_candidates { diff --git a/src/agentic_rewrite/runtime.rs b/src/agentic_rewrite/runtime.rs index d3bcdc6..7a8e26b 100644 --- a/src/agentic_rewrite/runtime.rs +++ b/src/agentic_rewrite/runtime.rs @@ -305,7 +305,7 @@ fn built_in_rules(default_policy: RewriteCorrectionPolicy) -> Vec { surface_kind: Some(RewriteSurfaceKind::Browser), ..ContextMatcher::default() }, - "Favor clean prose and natural punctuation for browser text fields, but stay grounded in the listed candidates, glossary evidence, and the utterance's technical topic when it clearly refers to software or documentation.", + "Favor clean prose and natural punctuation for browser text fields, except when the utterance is structured text such as a hostname, URL, email address, or similar punctuation-sensitive literal. In those cases preserve punctuation literally and do not rewrite it into prose. 
Stay grounded in the listed candidates, glossary evidence, and the utterance's technical topic when it clearly refers to software or documentation.", Some(RewriteCorrectionPolicy::Balanced), ), AppRule::built_in( @@ -314,7 +314,7 @@ fn built_in_rules(default_policy: RewriteCorrectionPolicy) -> Vec { surface_kind: Some(RewriteSurfaceKind::GenericText), ..ContextMatcher::default() }, - "Favor clean prose and natural punctuation for general text entry while staying grounded in the listed candidates and glossary evidence. If the utterance clearly discusses technical tools or software, prefer the most plausible technical term over a phonetically similar common word.", + "Favor clean prose and natural punctuation for general text entry, except when the utterance is structured text such as a hostname, URL, email address, or similar punctuation-sensitive literal. In those cases preserve punctuation literally and do not rewrite it into prose. Stay grounded in the listed candidates and glossary evidence. 
If the utterance clearly discusses technical tools or software, prefer the most plausible technical term over a phonetically similar common word.", Some(RewriteCorrectionPolicy::Balanced), ), AppRule::built_in( diff --git a/src/agentic_rewrite/store.rs b/src/agentic_rewrite/store.rs index a54c0b3..e785de7 100644 --- a/src/agentic_rewrite/store.rs +++ b/src/agentic_rewrite/store.rs @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize}; use crate::config::Config; use crate::error::{Result, WhsprError}; +use crate::safe_fs; use super::{AppRule, GlossaryEntry}; @@ -101,7 +102,7 @@ fn ensure_text_file(path: &Path, contents: &str) -> Result { } write_parent(path)?; - std::fs::write(path, contents).map_err(|e| { + safe_fs::write(path, contents).map_err(|e| { WhsprError::Config(format!( "failed to write starter file {}: {e}", path.display() @@ -115,7 +116,7 @@ pub(super) fn read_policy_file(path: &Path) -> Result> { return Ok(Vec::new()); } - let contents = std::fs::read_to_string(path).map_err(|e| { + let contents = safe_fs::read_to_string(path).map_err(|e| { WhsprError::Config(format!("failed to read app rules {}: {e}", path.display())) })?; if contents.trim().is_empty() { @@ -133,7 +134,7 @@ pub(super) fn write_policy_file(path: &Path, rules: &[AppRule]) -> Result<()> { rules: rules.to_vec(), }) .map_err(|e| WhsprError::Config(format!("failed to encode app rules: {e}")))?; - std::fs::write(path, contents).map_err(|e| { + safe_fs::write(path, contents).map_err(|e| { WhsprError::Config(format!("failed to write app rules {}: {e}", path.display())) })?; Ok(()) @@ -144,7 +145,7 @@ pub(super) fn read_glossary_file(path: &Path) -> Result> { return Ok(Vec::new()); } - let contents = std::fs::read_to_string(path).map_err(|e| { + let contents = safe_fs::read_to_string(path).map_err(|e| { WhsprError::Config(format!("failed to read glossary {}: {e}", path.display())) })?; if contents.trim().is_empty() { @@ -162,28 +163,30 @@ pub(super) fn write_glossary_file(path: &Path, entries: 
&[GlossaryEntry]) -> Res entries: entries.to_vec(), }) .map_err(|e| WhsprError::Config(format!("failed to encode glossary: {e}")))?; - std::fs::write(path, contents).map_err(|e| { + safe_fs::write(path, contents).map_err(|e| { WhsprError::Config(format!("failed to write glossary {}: {e}", path.display())) })?; Ok(()) } -pub(super) fn load_policy_file_for_runtime(path: &Path) -> Vec { +pub(super) fn load_policy_file_for_runtime_with_status(path: &Path) -> (Vec, bool) { match read_policy_file(path) { - Ok(rules) => rules, + Ok(rules) => (rules, false), Err(err) => { tracing::warn!("{err}; using built-in app rewrite defaults"); - Vec::new() + (Vec::new(), true) } } } -pub(super) fn load_glossary_file_for_runtime(path: &Path) -> Vec { +pub(super) fn load_glossary_file_for_runtime_with_status( + path: &Path, +) -> (Vec, bool) { match read_glossary_file(path) { - Ok(entries) => entries, + Ok(entries) => (entries, false), Err(err) => { tracing::warn!("{err}; ignoring runtime glossary"); - Vec::new() + (Vec::new(), true) } } } diff --git a/src/app/mod.rs b/src/app/mod.rs index 8a58ae6..858d883 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -3,19 +3,22 @@ use std::time::Instant; use crate::config::Config; use crate::error::Result; use crate::postprocess::finalize; +use crate::runtime_diagnostics::DictationRuntimeDiagnostics; +use crate::runtime_support::PidLock; mod osd; mod runtime; use runtime::DictationRuntime; -pub async fn run(config: Config) -> Result<()> { +pub async fn run(config: Config, _pid_lock: PidLock) -> Result<()> { let activation_started = Instant::now(); + let diagnostics = DictationRuntimeDiagnostics::new(&config); // Register signals before startup work to minimize early-signal races. 
let mut sigusr1 = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::user_defined1())?; let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?; - let mut runtime = DictationRuntime::new(config); + let mut runtime = DictationRuntime::new(config, diagnostics.clone()); let recording = runtime.start_recording()?; runtime.prepare_services()?; @@ -41,6 +44,7 @@ pub async fn run(config: Config) -> Result<()> { if transcribed.is_empty() { tracing::warn!("transcription returned empty text"); finalize::wait_for_feedback_drain().await; + diagnostics.clear_with_stage(crate::runtime_diagnostics::DictationStage::Done); return Ok(()); } @@ -53,10 +57,12 @@ pub async fn run(config: Config) -> Result<()> { // causes an audible click as the OS closes our audio file descriptors. // With speech, transcription takes seconds — providing natural drain time. finalize::wait_for_feedback_drain().await; + diagnostics.clear_with_stage(crate::runtime_diagnostics::DictationStage::Done); return Ok(()); } runtime.inject_finalized(finalized).await?; + diagnostics.clear_with_stage(crate::runtime_diagnostics::DictationStage::Done); tracing::info!("done"); tracing::info!( diff --git a/src/app/osd.rs b/src/app/osd.rs index d6a4af2..5821c6b 100644 --- a/src/app/osd.rs +++ b/src/app/osd.rs @@ -1,8 +1,13 @@ use std::process::Child; +use std::time::Duration; #[cfg(feature = "osd")] use std::process::Command; +use crate::runtime_guards::wait_child_with_timeout; + +const OSD_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(250); + #[cfg(feature = "osd")] pub(super) fn spawn_osd() -> Option { // Look for whispers-osd next to our own binary first, then fall back to PATH @@ -38,8 +43,21 @@ pub(super) fn kill_osd(child: &mut Option) { unsafe { libc::kill(pid, libc::SIGTERM); } - let _ = c.wait(); - tracing::debug!("whispers-osd (pid {pid}) terminated"); + match wait_child_with_timeout(&mut c, OSD_SHUTDOWN_TIMEOUT) { + Ok(Some(_)) => { + 
tracing::debug!("whispers-osd (pid {pid}) terminated"); + } + Ok(None) => { + unsafe { + libc::kill(pid, libc::SIGKILL); + } + let _ = wait_child_with_timeout(&mut c, OSD_SHUTDOWN_TIMEOUT); + tracing::warn!("whispers-osd (pid {pid}) did not exit after SIGTERM; killed"); + } + Err(err) => { + tracing::warn!("failed to wait for whispers-osd (pid {pid}) to exit: {err}"); + } + } } } diff --git a/src/app/runtime.rs b/src/app/runtime.rs index ba138b5..48ed6ea 100644 --- a/src/app/runtime.rs +++ b/src/app/runtime.rs @@ -8,15 +8,21 @@ use crate::context::{self, TypingContext}; use crate::error::Result; use crate::feedback::FeedbackPlayer; use crate::inject::TextInjector; -use crate::postprocess::{execution, finalize}; +use crate::postprocess::{execution, finalize, planning}; use crate::rewrite_worker::RewriteService; +use crate::runtime_diagnostics::{ + DictationRuntimeDiagnostics, DictationStage, DictationStageMetadata, +}; use crate::session::{self, EligibleSessionEntry}; use crate::transcribe::Transcript; pub(super) struct DictationRuntime { config: Config, + diagnostics: DictationRuntimeDiagnostics, feedback: FeedbackPlayer, session_enabled: bool, + runtime_text_resources: planning::RuntimeTextResources, + runtime_text_resources_degraded_reason: Option, transcriber: Option, rewrite_service: Option, } @@ -44,18 +50,23 @@ pub(super) struct ReadyInjection { } impl DictationRuntime { - pub(super) fn new(config: Config) -> Self { + pub(super) fn new(config: Config, diagnostics: DictationRuntimeDiagnostics) -> Self { let feedback = FeedbackPlayer::new( config.feedback.enabled, &config.feedback.start_sound, &config.feedback.stop_sound, ); let session_enabled = config.postprocess.mode.uses_rewrite(); + let (runtime_text_resources, runtime_text_resources_degraded_reason) = + load_runtime_text_resources_or_default(&config); Self { config, + diagnostics, feedback, session_enabled, + runtime_text_resources, + runtime_text_resources_degraded_reason, transcriber: None, 
rewrite_service: None, } @@ -67,7 +78,7 @@ impl DictationRuntime { self.feedback.play_start(); let recording_context = context::capture_typing_context(); let recent_session = if self.session_enabled { - session::load_recent_entry(&self.config.session, &recording_context)? + load_recent_session_entry_guarded(&self.config.session, &recording_context) } else { None }; @@ -75,6 +86,8 @@ impl DictationRuntime { let mut recorder = AudioRecorder::new(&self.config.audio); recorder.start()?; let osd = super::osd::spawn_osd(); + self.diagnostics + .enter_stage(DictationStage::Recording, DictationStageMetadata::default()); tracing::info!("recording... (run whispers again to stop)"); Ok(ActiveRecording { @@ -85,6 +98,10 @@ impl DictationRuntime { } pub(super) fn prepare_services(&mut self) -> Result<()> { + self.diagnostics.enter_stage( + DictationStage::AsrPrepare, + DictationStageMetadata::default(), + ); let transcriber = asr::prepare::prepare_transcriber(&self.config)?; let rewrite_service = execution::prepare_rewrite_service(&self.config); asr::prepare::prewarm_transcriber(&transcriber, "recording"); @@ -100,6 +117,7 @@ impl DictationRuntime { pub(super) fn cancel_recording(&self, mut recording: ActiveRecording) -> Result<()> { super::osd::kill_osd(&mut recording.osd); recording.recorder.stop()?; + self.diagnostics.clear_with_stage(DictationStage::Cancelled); Ok(()) } @@ -121,6 +139,15 @@ impl DictationRuntime { audio_duration_ms, "transcribing captured audio" ); + self.diagnostics.enter_stage( + DictationStage::RecordingStopped, + DictationStageMetadata { + audio_samples: Some(audio.len()), + sample_rate: Some(sample_rate), + audio_duration_ms: Some(audio_duration_ms), + ..DictationStageMetadata::default() + }, + ); Ok(CapturedRecording { audio, @@ -133,6 +160,19 @@ impl DictationRuntime { &mut self, recording: CapturedRecording, ) -> Result { + let audio_samples = recording.audio.len(); + let sample_rate = recording.sample_rate; + let audio_duration_ms = + 
((audio_samples as f64 / sample_rate as f64) * 1000.0).round() as u64; + self.diagnostics.enter_stage( + DictationStage::AsrTranscribe, + DictationStageMetadata { + audio_samples: Some(audio_samples), + sample_rate: Some(sample_rate), + audio_duration_ms: Some(audio_duration_ms), + ..DictationStageMetadata::default() + }, + ); let transcriber = self .transcriber .take() @@ -176,10 +216,21 @@ impl DictationRuntime { }); let finalize_started = Instant::now(); + self.diagnostics.enter_stage( + DictationStage::Postprocess, + DictationStageMetadata { + substage: Some("planning".into()), + detail: Some("build_rewrite_plan".into()), + transcript_chars: Some(recording.transcript.raw_text.len()), + degraded_reason: self.runtime_text_resources_degraded_reason.clone(), + ..DictationStageMetadata::default() + }, + ); let finalized = finalize::finalize_transcript( &self.config, recording.transcript, self.rewrite_service.as_ref(), + Some(&self.runtime_text_resources), Some(&injection_context), recent_session.as_ref(), ) @@ -211,20 +262,47 @@ impl DictationRuntime { text, operation, rewrite_summary, + degraded_reason, } = finalized; tracing::info!("injecting text: {:?}", text); + let stage_metadata = DictationStageMetadata { + substage: Some("inject".to_string()), + detail: Some(match operation { + finalize::FinalizedOperation::Append => "clipboard_paste".to_string(), + finalize::FinalizedOperation::ReplaceLastEntry { .. } => { + "replace_recent_text".to_string() + } + }), + output_chars: Some(text.len()), + operation: Some(match operation { + finalize::FinalizedOperation::Append => "append".to_string(), + finalize::FinalizedOperation::ReplaceLastEntry { .. 
} => { + "replace_last_entry".to_string() + } + }), + rewrite_used: Some(rewrite_summary.rewrite_used), + degraded_reason: degraded_reason.clone(), + ..DictationStageMetadata::default() + }; + self.diagnostics + .enter_stage(DictationStage::Inject, stage_metadata.clone()); let injector = TextInjector::new(); match operation { finalize::FinalizedOperation::Append => { injector.inject(&text).await?; if self.session_enabled { - session::record_append( + let mut session_metadata = stage_metadata.clone(); + session_metadata.substage = Some("session_write".into()); + session_metadata.detail = Some("record_append".into()); + self.diagnostics + .enter_stage(DictationStage::SessionWrite, session_metadata); + record_session_append_guarded( &self.config.session, &injection_context, &text, rewrite_summary, - )?; + ); } } finalize::FinalizedOperation::ReplaceLastEntry { @@ -235,13 +313,18 @@ impl DictationRuntime { .replace_recent_text(delete_graphemes, &text) .await?; if self.session_enabled { - session::record_replace( + let mut session_metadata = stage_metadata; + session_metadata.substage = Some("session_write".into()); + session_metadata.detail = Some("record_replace".into()); + self.diagnostics + .enter_stage(DictationStage::SessionWrite, session_metadata); + record_session_replace_guarded( &self.config.session, &injection_context, entry_id, &text, rewrite_summary, - )?; + ); } } } @@ -250,6 +333,59 @@ impl DictationRuntime { } } +fn load_runtime_text_resources_or_default( + config: &Config, +) -> (planning::RuntimeTextResources, Option) { + let (resources, degraded) = planning::load_runtime_text_resources_with_status(config); + if degraded { + (resources, Some("runtime_text_resources_unavailable".into())) + } else { + (resources, None) + } +} + +fn load_recent_session_entry_guarded( + config: &crate::config::SessionConfig, + context: &TypingContext, +) -> Option { + match session::load_recent_entry(config, context) { + Ok(entry) => entry, + Err(err) => { + 
tracing::warn!("failed to load recent session entry: {err}; continuing without it"); + None + } + } +} + +fn record_session_append_guarded( + config: &crate::config::SessionConfig, + context: &TypingContext, + text: &str, + rewrite_summary: crate::session::SessionRewriteSummary, +) { + match session::record_append(config, context, text, rewrite_summary) { + Ok(()) => {} + Err(err) => { + tracing::warn!("failed to persist session append: {err}; continuing"); + } + } +} + +fn record_session_replace_guarded( + config: &crate::config::SessionConfig, + context: &TypingContext, + entry_id: u64, + text: &str, + rewrite_summary: crate::session::SessionRewriteSummary, +) { + match session::record_replace(config, context, entry_id, text, rewrite_summary) { + Ok(()) => {} + Err(err) => { + tracing::warn!("failed to persist session replacement: {err}; continuing"); + } + } +} + impl TranscribedRecording { pub(super) fn is_empty(&self) -> bool { self.transcript.is_empty() @@ -261,3 +397,105 @@ impl ReadyInjection { self.finalized.text.is_empty() } } + +#[cfg(test)] +mod tests { + use std::ffi::CString; + use std::os::unix::ffi::OsStrExt; + use std::path::Path; + + use super::*; + use crate::context::SurfaceKind; + use crate::session::SessionRewriteSummary; + use crate::test_support::{EnvVarGuard, env_lock, set_env, unique_temp_dir}; + + fn typing_context() -> TypingContext { + TypingContext { + focus_fingerprint: "niri:7".into(), + app_id: Some("kitty".into()), + window_title: Some("shell".into()), + surface_kind: SurfaceKind::Terminal, + browser_domain: None, + captured_at_ms: 42, + } + } + + fn with_runtime_dir(f: impl FnOnce(&Path) -> T) -> T { + let _env_lock = env_lock(); + let _guard = EnvVarGuard::capture(&["XDG_RUNTIME_DIR"]); + let runtime_dir = unique_temp_dir("runtime-session-timeout"); + set_env( + "XDG_RUNTIME_DIR", + runtime_dir.to_str().expect("runtime dir should be utf-8"), + ); + f(&runtime_dir) + } + + fn mkfifo(path: &Path) { + let c_path = 
CString::new(path.as_os_str().as_bytes()).expect("fifo path"); + let result = unsafe { libc::mkfifo(c_path.as_ptr(), 0o600) }; + assert_eq!( + result, + 0, + "mkfifo failed: {}", + std::io::Error::last_os_error() + ); + } + + #[test] + fn load_recent_session_entry_skips_blocking_fifo() { + with_runtime_dir(|runtime_dir| { + let session_dir = runtime_dir.join("whispers"); + std::fs::create_dir_all(&session_dir).expect("session dir"); + mkfifo(&session_dir.join("session.json")); + + let recent = load_recent_session_entry_guarded( + &crate::config::SessionConfig::default(), + &typing_context(), + ); + + assert!(recent.is_none()); + }); + } + + #[test] + fn record_session_append_skips_blocking_fifo() { + with_runtime_dir(|runtime_dir| { + let session_dir = runtime_dir.join("whispers"); + std::fs::create_dir_all(&session_dir).expect("session dir"); + mkfifo(&session_dir.join("session.json")); + + record_session_append_guarded( + &crate::config::SessionConfig::default(), + &typing_context(), + "hello", + SessionRewriteSummary { + had_edit_cues: false, + rewrite_used: false, + recommended_candidate: None, + }, + ); + }); + } + + #[test] + fn load_runtime_text_resources_falls_back_to_defaults_for_blocking_fifo() { + let mut config = Config::default(); + config.postprocess.mode = crate::config::PostprocessMode::Rewrite; + with_runtime_dir(|runtime_dir| { + let dictionary_path = runtime_dir.join("blocking-dictionary.toml"); + mkfifo(&dictionary_path); + config.personalization.dictionary_path = dictionary_path + .to_str() + .expect("dictionary path") + .to_string(); + + let (_, degraded_reason) = load_runtime_text_resources_or_default(&config); + + assert_eq!( + degraded_reason.as_deref(), + Some("runtime_text_resources_unavailable") + ); + }); + } +} diff --git a/src/asr/cleanup.rs b/src/asr/cleanup.rs index 92f6948..c577e2c 100644 --- a/src/asr/cleanup.rs +++ b/src/asr/cleanup.rs @@ -134,39 +134,59 @@ fn asr_runtime_scope_dir() -> PathBuf { #[cfg(test)] mod tests { use 
super::*; + use crate::test_support::{EnvVarGuard, env_lock, set_env, unique_temp_dir}; + + fn with_runtime_dir(f: impl FnOnce(PathBuf) -> T) -> T { + let _env_lock = env_lock(); + let _guard = EnvVarGuard::capture(&["XDG_RUNTIME_DIR"]); + let runtime_dir = unique_temp_dir("asr-cleanup-runtime"); + set_env( + "XDG_RUNTIME_DIR", + runtime_dir + .to_str() + .expect("runtime dir should be valid UTF-8"), + ); + f(runtime_dir.join("whispers")) + } #[test] fn parse_faster_worker_cmdline_extracts_socket_path() { - let socket = asr_runtime_scope_dir().join("asr-faster-123.sock"); - let cmdline = format!( - "/home/user/.local/share/whispers/faster-whisper/venv/bin/python\0/home/user/.local/share/whispers/faster-whisper/faster_whisper_worker.py\0serve\0--socket-path\0{}\0--model-dir\0/tmp/model\0", - socket.display() - ); - let parsed = parse_asr_worker_cmdline(cmdline.as_bytes()).expect("parse worker"); - assert_eq!(parsed.0, "faster_whisper"); - assert_eq!(parsed.1, socket); + with_runtime_dir(|runtime_scope| { + let socket = runtime_scope.join("asr-faster-123.sock"); + let cmdline = format!( + "/home/user/.local/share/whispers/faster-whisper/venv/bin/python\0/home/user/.local/share/whispers/faster-whisper/faster_whisper_worker.py\0serve\0--socket-path\0{}\0--model-dir\0/tmp/model\0", + socket.display() + ); + let parsed = parse_asr_worker_cmdline(cmdline.as_bytes()).expect("parse worker"); + assert_eq!(parsed.0, "faster_whisper"); + assert_eq!(parsed.1, socket); + }); } #[test] fn parse_nemo_worker_cmdline_extracts_socket_path() { - let socket = asr_runtime_scope_dir().join("asr-nemo-456.sock"); - let cmdline = format!( - "/home/user/.local/share/whispers/nemo/venv-asr/bin/python\0/home/user/.local/share/whispers/nemo/nemo_asr_worker.py\0serve\0--socket-path\0{}\0--model-ref\0/tmp/model.nemo\0", - socket.display() - ); - let parsed = parse_asr_worker_cmdline(cmdline.as_bytes()).expect("parse worker"); - assert_eq!(parsed.0, "nemo"); - assert_eq!(parsed.1, socket); + 
with_runtime_dir(|runtime_scope| { + let socket = runtime_scope.join("asr-nemo-456.sock"); + let cmdline = format!( + "/home/user/.local/share/whispers/nemo/venv-asr/bin/python\0/home/user/.local/share/whispers/nemo/nemo_asr_worker.py\0serve\0--socket-path\0{}\0--model-ref\0/tmp/model.nemo\0", + socket.display() + ); + let parsed = parse_asr_worker_cmdline(cmdline.as_bytes()).expect("parse worker"); + assert_eq!(parsed.0, "nemo"); + assert_eq!(parsed.1, socket); + }); } #[test] fn parse_asr_worker_cmdline_ignores_unrelated_processes() { - let socket = asr_runtime_scope_dir().join("asr-other.sock"); - let cmdline = format!( - "/usr/bin/python\0/home/user/script.py\0serve\0--socket-path\0{}\0", - socket.display() - ); - assert!(parse_asr_worker_cmdline(cmdline.as_bytes()).is_none()); + with_runtime_dir(|runtime_scope| { + let socket = runtime_scope.join("asr-other.sock"); + let cmdline = format!( + "/usr/bin/python\0/home/user/script.py\0serve\0--socket-path\0{}\0", + socket.display() + ); + assert!(parse_asr_worker_cmdline(cmdline.as_bytes()).is_none()); + }); } #[test] diff --git a/src/bin/whispers-rewrite-worker/prompt.rs b/src/bin/whispers-rewrite-worker/prompt.rs index 9f2b751..887c477 100644 --- a/src/bin/whispers-rewrite-worker/prompt.rs +++ b/src/bin/whispers-rewrite-worker/prompt.rs @@ -141,7 +141,9 @@ word. Use nearby category words like window manager, editor, language, library, tool to disambiguate technical names. When a dictated word is an obvious phonetic near-miss for a likely technical term \ and the surrounding context clearly identifies the category, correct it to the canonical technical spelling instead of \ echoing the miss. If multiple plausible interpretations remain similarly credible, stay close to the transcript rather \ -than inventing a niche term. \ +than inventing a niche term. When the utterance is a hostname, URL, email address, or other structured text, preserve \ +dots, slashes, colons, dashes, underscores, and at-signs literally. 
Do not turn structured labels into sentence \ +punctuation and do not append explanatory prose such as saying something is a URL. \ If an edit intent says to replace or cancel previous wording, preserve that edit when the utterance or same-session \ context clearly supports it. Preserve utterance-initial courtesy or apology wording when the raw transcript still \ clearly intends it. Examples:\n\ @@ -175,7 +177,12 @@ phonetically similar common word. Use nearby category words like window manager, manager, shell, or terminal tool to disambiguate technical names. When a dictated word is an obvious phonetic near-miss \ for a likely technical term and the surrounding context clearly identifies the category, correct it to the canonical \ technical spelling instead of echoing the miss. If multiple plausible interpretations remain similarly credible, stay \ -close to the transcript rather than inventing a niche term. If an edit intent says to replace or cancel previous wording, preserve that edit when the utterance or same-session context clearly supports it. Preserve utterance-initial courtesy or apology wording when the raw transcript still clearly intends it. Examples:\n\ +close to the transcript rather than inventing a niche term. When the utterance is a hostname, URL, email address, or \ +other structured text, preserve dots, slashes, colons, dashes, underscores, and at-signs literally. Do not turn \ +structured labels into sentence punctuation and do not append explanatory prose such as saying something is a URL. If \ +an edit intent says to replace or cancel previous wording, preserve that edit when the utterance or same-session \ +context clearly supports it. Preserve utterance-initial courtesy or apology wording when the raw transcript still \ +clearly intends it. Examples:\n\ - raw: Hello there. Scratch that. 
Hi.\n correction-aware: Hi.\n final: Hi.\n\ - raw: I'll bring cookies, scratch that, brownies.\n correction-aware: I'll bring brownies.\n final: I'll bring brownies.\n\ - raw: My name is Notes, scratch that my name is Jonatan.\n correction-aware: My my name is Jonatan.\n aggressive correction-aware: My name is Jonatan.\n final: My name is Jonatan.\n\ @@ -320,7 +327,8 @@ Structured cue context:\n\ Self-corrections were already resolved before rewriting.\n\ Use this correction-aware transcript as the main source text. In agentic mode, you may still normalize likely \ technical terms or proper names when the utterance strongly supports them, even if the exact canonical spelling is not \ -already present in the candidate list:\n\ +already present in the candidate list. When a structured-text candidate is present, preserve its punctuation literally \ +and do not rewrite it into prose:\n\ {correction_aware}\n\ {agentic_candidates}\ Do not restore any canceled wording from earlier in the utterance.\n\ @@ -342,7 +350,8 @@ Correction-aware transcript:\n\ {correction_aware}\n\ Treat the correction-aware transcript as authoritative for explicit spoken edits and overall meaning, but in agentic \ mode you may normalize likely technical terms or proper names when category cues in the utterance make the intended \ -technical meaning clearly better than the literal transcript.\n\ +technical meaning clearly better than the literal transcript. 
When a structured-text candidate is present, preserve \ +its punctuation literally and do not rewrite it into prose.\n\ {agentic_candidates}\ \ Recent segments:\n\ @@ -555,6 +564,9 @@ fn render_rewrite_candidates(transcript: &RewriteTranscript) -> String { crate::rewrite_protocol::RewriteCandidateKind::Literal => { "literal (keep only if the cue was not actually an edit)" } + crate::rewrite_protocol::RewriteCandidateKind::StructuredLiteral => { + "structured_literal (preserve structured punctuation literally)" + } crate::rewrite_protocol::RewriteCandidateKind::ConservativeCorrection => { "conservative_correction (balanced cleanup)" } diff --git a/src/bin/whispers-rewrite-worker/rewrite_protocol.rs b/src/bin/whispers-rewrite-worker/rewrite_protocol.rs index 33b13ff..cbd05b8 100644 --- a/src/bin/whispers-rewrite-worker/rewrite_protocol.rs +++ b/src/bin/whispers-rewrite-worker/rewrite_protocol.rs @@ -192,6 +192,7 @@ pub enum RewriteCorrectionPolicy { #[serde(rename_all = "snake_case")] pub enum RewriteCandidateKind { Literal, + StructuredLiteral, ConservativeCorrection, AggressiveCorrection, GlossaryCorrection, diff --git a/src/context.rs b/src/context.rs index 5299fe9..ea1761a 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,9 +1,14 @@ use std::process::Command; +use std::time::Duration; use std::time::{SystemTime, UNIX_EPOCH}; use serde::{Deserialize, Serialize}; use serde_json::Value; +use crate::runtime_guards::run_command_output_with_timeout; + +const CONTEXT_COMMAND_TIMEOUT: Duration = Duration::from_millis(250); + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct TypingContext { pub focus_fingerprint: String, @@ -71,10 +76,7 @@ fn capture_niri_context(captured_at_ms: u64) -> Option { return None; } - let output = Command::new("niri") - .args(["msg", "-j", "focused-window"]) - .output() - .ok()?; + let output = run_context_command("niri", &["msg", "-j", "focused-window"])?; if !output.status.success() { return None; } @@ -86,10 
+88,7 @@ fn capture_niri_context(captured_at_ms: u64) -> Option { fn capture_hyprland_context(captured_at_ms: u64) -> Option { std::env::var_os("HYPRLAND_INSTANCE_SIGNATURE")?; - let output = Command::new("hyprctl") - .args(["activewindow", "-j"]) - .output() - .ok()?; + let output = run_context_command("hyprctl", &["activewindow", "-j"])?; if !output.status.success() { return None; } @@ -131,6 +130,18 @@ fn parse_niri_focused_window_json(raw: &str, captured_at_ms: u64) -> Option Option { + let mut command = Command::new(program); + command.args(args); + match run_command_output_with_timeout(&mut command, CONTEXT_COMMAND_TIMEOUT) { + Ok(output) => Some(output), + Err(err) => { + tracing::debug!("context command {program} failed: {err}"); + None + } + } +} + fn parse_hyprland_activewindow_json(raw: &str, captured_at_ms: u64) -> Option { let value: Value = serde_json::from_str(raw).ok()?; let address = json_string(&value, "address").unwrap_or_default(); @@ -325,6 +336,7 @@ fn now_ms() -> u64 { #[cfg(test)] mod tests { use super::*; + use crate::test_support::{EnvVarGuard, env_lock, set_env, unique_temp_dir}; #[test] fn classify_surface_kind_detects_terminal() { @@ -404,4 +416,26 @@ mod tests { Some("news.ycombinator.com") ); } + + #[test] + fn capture_typing_context_returns_unknown_when_niri_hangs() { + let _env_lock = env_lock(); + let _guard = EnvVarGuard::capture(&["PATH", "XDG_CURRENT_DESKTOP"]); + let bin_dir = unique_temp_dir("context-timeout-bin"); + let niri_path = bin_dir.join("niri"); + std::fs::write(&niri_path, "#!/bin/sh\n/bin/sleep 5\n").expect("write niri"); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(&niri_path, std::fs::Permissions::from_mode(0o755)) + .expect("chmod niri"); + } + set_env("PATH", bin_dir.to_str().expect("bin dir")); + set_env("XDG_CURRENT_DESKTOP", "niri"); + + let context = capture_typing_context(); + + assert!(!context.is_known_focus()); + assert_eq!(context.surface_kind, 
SurfaceKind::Unknown); + } } diff --git a/src/faster_whisper.rs b/src/faster_whisper.rs index 9c01336..bc70a1d 100644 --- a/src/faster_whisper.rs +++ b/src/faster_whisper.rs @@ -13,10 +13,15 @@ use tokio::process::Command; use crate::asr_protocol::{AsrRequest, AsrResponse}; use crate::config::{TranscriptionBackend, TranscriptionConfig, data_dir}; use crate::error::{Result, WhsprError}; +use crate::runtime_guards::run_command_status_with_timeout; use crate::transcribe::Transcript; const PYTHON_WORKER_SOURCE: &str = include_str!("faster_whisper_worker.py"); const RUNTIME_READY_MARKER: &str = ".runtime-ready"; +const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); +const RUNTIME_SETUP_TIMEOUT: Duration = Duration::from_secs(1_800); +const MODEL_DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(7_200); +const VERSION_PROBE_TIMEOUT: Duration = Duration::from_secs(2); struct ManagedModelInfo { name: &'static str, @@ -73,7 +78,8 @@ pub async fn download_managed_model(name: &str) -> Result<()> { "Preparing faster-whisper runtime and model {}...", model.name )); - let status = std::process::Command::new(&python) + let mut command = std::process::Command::new(&python); + command .arg(&script) .arg("download") .arg("--repo-id") @@ -81,8 +87,8 @@ pub async fn download_managed_model(name: &str) -> Result<()> { .arg("--model-dir") .arg(&model_dir) .stdout(crate::ui::child_stdio()) - .stderr(crate::ui::child_stdio()) - .status() + .stderr(crate::ui::child_stdio()); + let status = run_command_status_with_timeout(&mut command, MODEL_DOWNLOAD_TIMEOUT) .map_err(|e| { WhsprError::Download(format!( "failed to start faster-whisper downloader via {}: {e}", @@ -265,23 +271,31 @@ impl FasterWhisperService { } pub async fn transcribe(&self, audio: &[f32], sample_rate: u32) -> Result { - let timeout = Duration::from_millis(60_000); + self.transcribe_with_timeout(audio, sample_rate, REQUEST_TIMEOUT) + .await + } + + async fn transcribe_with_timeout( + &self, + audio: &[f32], + 
sample_rate: u32, + timeout: Duration, + ) -> Result { self.ensure_running(timeout).await?; + let deadline = tokio::time::Instant::now() + timeout; - let mut stream = tokio::time::timeout(timeout, UnixStream::connect(&self.socket_path)) - .await - .map_err(|_| { - WhsprError::Transcription(format!( - "faster-whisper worker timed out after {}ms", - timeout.as_millis() - )) - })? - .map_err(|e| { - WhsprError::Transcription(format!( - "failed to connect to faster-whisper worker at {}: {e}", - self.socket_path.display() - )) - })?; + let mut stream = tokio::time::timeout( + remaining_request_budget(deadline, timeout)?, + UnixStream::connect(&self.socket_path), + ) + .await + .map_err(|_| asr_timeout_error(timeout))? + .map_err(|e| { + WhsprError::Transcription(format!( + "failed to connect to faster-whisper worker at {}: {e}", + self.socket_path.display() + )) + })?; let mut audio_bytes = Vec::with_capacity(std::mem::size_of_val(audio)); for sample in audio { @@ -293,26 +307,27 @@ impl FasterWhisperService { }) .map_err(|e| WhsprError::Transcription(format!("failed to encode ASR request: {e}")))?; payload.push(b'\n'); - stream - .write_all(&payload) - .await - .map_err(|e| WhsprError::Transcription(format!("failed to send ASR request: {e}")))?; - stream - .flush() + tokio::time::timeout( + remaining_request_budget(deadline, timeout)?, + stream.write_all(&payload), + ) + .await + .map_err(|_| asr_timeout_error(timeout))? + .map_err(|e| WhsprError::Transcription(format!("failed to send ASR request: {e}")))?; + tokio::time::timeout(remaining_request_budget(deadline, timeout)?, stream.flush()) .await + .map_err(|_| asr_timeout_error(timeout))? 
.map_err(|e| WhsprError::Transcription(format!("failed to flush ASR request: {e}")))?; let mut reader = BufReader::new(stream); let mut line = String::new(); - tokio::time::timeout(timeout, reader.read_line(&mut line)) - .await - .map_err(|_| { - WhsprError::Transcription(format!( - "faster-whisper worker timed out after {}ms", - timeout.as_millis() - )) - })? - .map_err(|e| WhsprError::Transcription(format!("failed to read ASR response: {e}")))?; + tokio::time::timeout( + remaining_request_budget(deadline, timeout)?, + reader.read_line(&mut line), + ) + .await + .map_err(|_| asr_timeout_error(timeout))? + .map_err(|e| WhsprError::Transcription(format!("failed to read ASR response: {e}")))?; if line.trim().is_empty() { return Err(WhsprError::Transcription( @@ -329,6 +344,19 @@ impl FasterWhisperService { } } +fn remaining_request_budget(deadline: tokio::time::Instant, timeout: Duration) -> Result { + deadline + .checked_duration_since(tokio::time::Instant::now()) + .ok_or_else(|| asr_timeout_error(timeout)) +} + +fn asr_timeout_error(timeout: Duration) -> WhsprError { + WhsprError::Transcription(format!( + "faster-whisper worker timed out after {}ms", + timeout.as_millis() + )) +} + fn find_managed_model(name: &str) -> Option<&'static ManagedModelInfo> { MANAGED_MODELS.iter().find(|info| info.name == name) } @@ -478,12 +506,10 @@ fn ensure_runtime_sync() -> Result<()> { let python3 = system_python()?; let venv_dir = runtime_dir.join("venv"); - let status = std::process::Command::new(&python3) - .arg("-m") - .arg("venv") - .arg(&venv_dir) - .status() - .map_err(|e| { + let mut create_venv = std::process::Command::new(&python3); + create_venv.arg("-m").arg("venv").arg(&venv_dir); + let status = + run_command_status_with_timeout(&mut create_venv, RUNTIME_SETUP_TIMEOUT).map_err(|e| { WhsprError::Transcription(format!("failed to create faster-whisper venv: {e}")) })?; if !status.success() { @@ -493,11 +519,9 @@ fn ensure_runtime_sync() -> Result<()> { } let pip = 
venv_dir.join("bin").join("pip"); - let status = std::process::Command::new(&pip) - .arg("install") - .arg("--upgrade") - .arg("pip") - .status() + let mut bootstrap_pip = std::process::Command::new(&pip); + bootstrap_pip.arg("install").arg("--upgrade").arg("pip"); + let status = run_command_status_with_timeout(&mut bootstrap_pip, RUNTIME_SETUP_TIMEOUT) .map_err(|e| WhsprError::Transcription(format!("failed to bootstrap pip: {e}")))?; if !status.success() { return Err(WhsprError::Transcription(format!( @@ -505,12 +529,13 @@ fn ensure_runtime_sync() -> Result<()> { ))); } - let status = std::process::Command::new(&pip) + let mut install_runtime = std::process::Command::new(&pip); + install_runtime .arg("install") .arg("faster-whisper") .arg("huggingface-hub") - .arg("numpy") - .status() + .arg("numpy"); + let status = run_command_status_with_timeout(&mut install_runtime, RUNTIME_SETUP_TIMEOUT) .map_err(|e| { WhsprError::Transcription(format!("failed to install faster-whisper runtime: {e}")) })?; @@ -531,10 +556,9 @@ fn ensure_runtime_sync() -> Result<()> { fn system_python() -> Result { for candidate in ["python3", "python"] { - match std::process::Command::new(candidate) - .arg("--version") - .status() - { + let mut command = std::process::Command::new(candidate); + command.arg("--version"); + match run_command_status_with_timeout(&mut command, VERSION_PROBE_TIMEOUT) { Ok(status) if status.success() => return Ok(PathBuf::from(candidate)), _ => continue, } @@ -609,6 +633,7 @@ impl Drop for StartupLock { #[cfg(test)] mod tests { use super::*; + use crate::test_support::unique_temp_dir; #[test] fn managed_model_path_uses_data_dir() { @@ -661,4 +686,47 @@ mod tests { assert_eq!(dirs, vec![lib_dir]); } + + #[tokio::test] + async fn transcribe_times_out_when_request_write_stalls() { + let runtime_dir = unique_temp_dir("faster-whisper-request-timeout"); + let socket_path = runtime_dir.join("asr.sock"); + let listener = 
tokio::net::UnixListener::bind(&socket_path).expect("bind stalled socket"); + + let server = tokio::spawn(async move { + let (probe, _) = listener.accept().await.expect("accept readiness probe"); + drop(probe); + let (_stalled_request, _) = listener.accept().await.expect("accept ASR request"); + tokio::time::sleep(Duration::from_secs(1)).await; + }); + + let service = FasterWhisperService { + socket_path: socket_path.clone(), + lock_path: runtime_dir.join("asr.lock"), + model_path: PathBuf::from("/tmp/test-model"), + language: "en".into(), + use_gpu: false, + idle_timeout_ms: 0, + }; + + let err = service + .transcribe_with_timeout( + &oversized_audio(512 * 1024), + 16_000, + Duration::from_millis(50), + ) + .await + .expect_err("stalled ASR request should time out"); + let message = match err { + WhsprError::Transcription(message) => message, + other => panic!("unexpected error: {other:?}"), + }; + assert!(message.contains("faster-whisper worker timed out")); + + server.abort(); + } + + fn oversized_audio(samples: usize) -> Vec { + vec![0.0; samples] + } } diff --git a/src/inject/clipboard.rs b/src/inject/clipboard.rs index dc9b0e8..e9a7906 100644 --- a/src/inject/clipboard.rs +++ b/src/inject/clipboard.rs @@ -2,6 +2,7 @@ use std::process::{Command, Stdio}; use std::time::Duration; use crate::error::{Result, WhsprError}; +use crate::runtime_guards::wait_child_with_timeout; pub(super) struct ClipboardAdapter<'a> { wl_copy_bin: &'a str, @@ -66,7 +67,7 @@ pub(super) fn run_wl_copy_with_timeout( } if std::time::Instant::now() >= deadline { let _ = wl_copy.kill(); - let _ = wl_copy.wait(); + let _ = wait_child_with_timeout(&mut wl_copy, Duration::from_millis(100)); return Err(WhsprError::Injection(format!( "wl-copy timed out after {}ms", timeout.as_millis() diff --git a/src/inject/tests.rs b/src/inject/tests.rs index 45db51f..921c059 100644 --- a/src/inject/tests.rs +++ b/src/inject/tests.rs @@ -33,8 +33,8 @@ fn run_wl_copy_reports_non_zero_exit() { #[test] fn 
run_wl_copy_reports_timeout() { let err = clipboard::run_wl_copy_with_timeout( - "/bin/sh", - &[String::from("-c"), String::from("sleep 1")], + "/bin/sleep", + &[String::from("1")], "hello", std::time::Duration::from_millis(80), ) diff --git a/src/lib.rs b/src/lib.rs index 64658ec..cfa5a83 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,9 +26,13 @@ pub mod rewrite_model; pub mod rewrite_profile; pub mod rewrite_protocol; pub mod rewrite_worker; +pub(crate) mod runtime_diagnostics; +pub(crate) mod runtime_guards; pub mod runtime_support; +pub(crate) mod safe_fs; pub mod session; pub mod setup; +pub(crate) mod structured_text; #[cfg(test)] pub mod test_support; pub mod transcribe; diff --git a/src/main.rs b/src/main.rs index 90cee45..f8f05f2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -46,6 +46,14 @@ async fn transcribe_file(cli: &Cli, file: &Path, output: Option<&Path>, raw: boo let transcript = asr::execute::transcribe_audio(&config, prepared, samples, file_audio::TARGET_SAMPLE_RATE) .await?; + let runtime_text_resources = if raw { + None + } else { + Some(load_runtime_text_resources_or_default( + &config, + "file transcription", + )) + }; let text = if raw { postprocess::planning::raw_text(&transcript) @@ -54,6 +62,7 @@ async fn transcribe_file(cli: &Cli, file: &Path, output: Option<&Path>, raw: boo &config, transcript, rewrite_service.as_ref(), + runtime_text_resources.as_ref(), None, None, ) @@ -71,8 +80,22 @@ async fn transcribe_file(cli: &Cli, file: &Path, output: Option<&Path>, raw: boo Ok(()) } +fn load_runtime_text_resources_or_default( + config: &Config, + phase: &'static str, +) -> postprocess::planning::RuntimeTextResources { + let (resources, degraded) = + postprocess::planning::load_runtime_text_resources_with_status(config); + if degraded { + tracing::warn!( + "runtime text resources for {phase} were degraded; using available defaults" + ); + } + resources +} + async fn run_default(cli: &Cli) -> Result<()> { - let Some(_pid_lock) = 
runtime_support::acquire_or_signal_lock()? else { + let Some(pid_lock) = runtime_support::acquire_or_signal_lock()? else { return Ok(()); }; @@ -83,7 +106,7 @@ async fn run_default(cli: &Cli) -> Result<()> { asr::validation::validate_transcription_config(&config)?; tracing::debug!("config loaded: {config:?}"); - app::run(config).await + app::run(config, pid_lock).await } #[tokio::main] diff --git a/src/nemo_asr.rs b/src/nemo_asr.rs index 0ab2827..f6c6ab1 100644 --- a/src/nemo_asr.rs +++ b/src/nemo_asr.rs @@ -15,6 +15,7 @@ use tokio::process::Command; use crate::asr_protocol::{AsrRequest, AsrResponse}; use crate::config::{TranscriptionBackend, TranscriptionConfig, data_dir}; use crate::error::{Result, WhsprError}; +use crate::runtime_guards::{run_command_output_with_timeout, run_command_status_with_timeout}; use crate::transcribe::Transcript; const PYTHON_WORKER_SOURCE: &str = include_str!("nemo_asr_worker.py"); @@ -23,6 +24,9 @@ const MODEL_READY_METADATA: &str = ".model-ready.json"; const STARTUP_LOCK_STALE_AGE: Duration = Duration::from_secs(600); const STARTUP_READY_TIMEOUT: Duration = Duration::from_secs(240); const REQUEST_TIMEOUT: Duration = Duration::from_secs(120); +const RUNTIME_SETUP_TIMEOUT: Duration = Duration::from_secs(1_800); +const MODEL_DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(7_200); +const VERSION_PROBE_TIMEOUT: Duration = Duration::from_secs(2); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] @@ -131,7 +135,8 @@ pub async fn download_managed_model(name: &str) -> Result<()> { "Preparing experimental NeMo runtime and model {}...", model.name )); - let status = std::process::Command::new(&python) + let mut command = std::process::Command::new(&python); + command .arg(&script) .arg("download") .arg("--model-ref") @@ -143,8 +148,8 @@ pub async fn download_managed_model(name: &str) -> Result<()> { .arg("--cache-dir") .arg(nemo_cache_dir()) .stdout(crate::ui::child_stdio()) - 
.stderr(crate::ui::child_stdio()) - .status() + .stderr(crate::ui::child_stdio()); + let status = run_command_status_with_timeout(&mut command, MODEL_DOWNLOAD_TIMEOUT) .map_err(|e| { WhsprError::Download(format!( "failed to start NeMo downloader via {}: {e}", @@ -347,23 +352,31 @@ impl NemoAsrService { } pub async fn transcribe(&self, audio: &[f32], sample_rate: u32) -> Result { + self.transcribe_with_timeout(audio, sample_rate, REQUEST_TIMEOUT) + .await + } + + async fn transcribe_with_timeout( + &self, + audio: &[f32], + sample_rate: u32, + timeout: Duration, + ) -> Result { self.ensure_running(STARTUP_READY_TIMEOUT).await?; + let deadline = tokio::time::Instant::now() + timeout; - let mut stream = - tokio::time::timeout(REQUEST_TIMEOUT, UnixStream::connect(&self.socket_path)) - .await - .map_err(|_| { - WhsprError::Transcription(format!( - "NeMo ASR worker timed out after {}ms", - REQUEST_TIMEOUT.as_millis() - )) - })? - .map_err(|e| { - WhsprError::Transcription(format!( - "failed to connect to NeMo ASR worker at {}: {e}", - self.socket_path.display() - )) - })?; + let mut stream = tokio::time::timeout( + remaining_request_budget(deadline, timeout)?, + UnixStream::connect(&self.socket_path), + ) + .await + .map_err(|_| asr_timeout_error(timeout))? + .map_err(|e| { + WhsprError::Transcription(format!( + "failed to connect to NeMo ASR worker at {}: {e}", + self.socket_path.display() + )) + })?; let mut audio_bytes = Vec::with_capacity(std::mem::size_of_val(audio)); for sample in audio { @@ -375,26 +388,27 @@ impl NemoAsrService { }) .map_err(|e| WhsprError::Transcription(format!("failed to encode ASR request: {e}")))?; payload.push(b'\n'); - stream - .write_all(&payload) - .await - .map_err(|e| WhsprError::Transcription(format!("failed to send ASR request: {e}")))?; - stream - .flush() + tokio::time::timeout( + remaining_request_budget(deadline, timeout)?, + stream.write_all(&payload), + ) + .await + .map_err(|_| asr_timeout_error(timeout))? 
+ .map_err(|e| WhsprError::Transcription(format!("failed to send ASR request: {e}")))?; + tokio::time::timeout(remaining_request_budget(deadline, timeout)?, stream.flush()) .await + .map_err(|_| asr_timeout_error(timeout))? .map_err(|e| WhsprError::Transcription(format!("failed to flush ASR request: {e}")))?; let mut reader = BufReader::new(stream); let mut line = String::new(); - tokio::time::timeout(REQUEST_TIMEOUT, reader.read_line(&mut line)) - .await - .map_err(|_| { - WhsprError::Transcription(format!( - "NeMo ASR worker timed out after {}ms", - REQUEST_TIMEOUT.as_millis() - )) - })? - .map_err(|e| WhsprError::Transcription(format!("failed to read ASR response: {e}")))?; + tokio::time::timeout( + remaining_request_budget(deadline, timeout)?, + reader.read_line(&mut line), + ) + .await + .map_err(|_| asr_timeout_error(timeout))? + .map_err(|e| WhsprError::Transcription(format!("failed to read ASR response: {e}")))?; if line.trim().is_empty() { return Err(WhsprError::Transcription( @@ -411,6 +425,19 @@ impl NemoAsrService { } } +fn remaining_request_budget(deadline: tokio::time::Instant, timeout: Duration) -> Result { + deadline + .checked_duration_since(tokio::time::Instant::now()) + .ok_or_else(|| asr_timeout_error(timeout)) +} + +fn asr_timeout_error(timeout: Duration) -> WhsprError { + WhsprError::Transcription(format!( + "NeMo ASR worker timed out after {}ms", + timeout.as_millis() + )) +} + fn find_managed_model(name: &str) -> Option<&'static ManagedModelInfo> { MANAGED_MODELS.iter().find(|info| info.name == name) } @@ -522,11 +549,9 @@ fn ensure_runtime_sync(family: NemoModelFamily) -> Result<()> { let python3 = system_python()?; let venv_dir = runtime_dir; - let status = std::process::Command::new(&python3) - .arg("-m") - .arg("venv") - .arg(&venv_dir) - .status() + let mut create_venv = std::process::Command::new(&python3); + create_venv.arg("-m").arg("venv").arg(&venv_dir); + let status = run_command_status_with_timeout(&mut create_venv, 
RUNTIME_SETUP_TIMEOUT) .map_err(|e| WhsprError::Transcription(format!("failed to create NeMo venv: {e}")))?; if !status.success() { return Err(WhsprError::Transcription(format!( @@ -535,13 +560,14 @@ fn ensure_runtime_sync(family: NemoModelFamily) -> Result<()> { } let pip = venv_dir.join("bin").join("pip"); - let status = std::process::Command::new(&pip) + let mut bootstrap_pip = std::process::Command::new(&pip); + bootstrap_pip .arg("install") .arg("--upgrade") .arg("pip") .arg("setuptools") - .arg("wheel") - .status() + .arg("wheel"); + let status = run_command_status_with_timeout(&mut bootstrap_pip, RUNTIME_SETUP_TIMEOUT) .map_err(|e| WhsprError::Transcription(format!("failed to bootstrap pip: {e}")))?; if !status.success() { return Err(WhsprError::Transcription(format!( @@ -581,11 +607,11 @@ fn system_python() -> Result { } fn python_version>(python: S) -> Option<(u32, u32)> { - let output = std::process::Command::new(python) + let mut command = std::process::Command::new(python); + command .arg("-c") - .arg("import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") - .output() - .ok()?; + .arg("import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')"); + let output = run_command_output_with_timeout(&mut command, VERSION_PROBE_TIMEOUT).ok()?; if !output.status.success() { return None; } @@ -626,9 +652,9 @@ fn install_runtime_packages(pip: &Path, profile: NemoRuntimeProfile) -> Result<( } fn run_pip_install(pip: &Path, args: &[&str], label: &str) -> Result<()> { - let status = std::process::Command::new(pip) - .args(args) - .status() + let mut command = std::process::Command::new(pip); + command.args(args); + let status = run_command_status_with_timeout(&mut command, RUNTIME_SETUP_TIMEOUT) .map_err(|e| WhsprError::Transcription(format!("failed to install {label}: {e}")))?; if !status.success() { return Err(WhsprError::Transcription(format!( @@ -742,6 +768,7 @@ impl Drop for StartupLock { mod tests { use super::*; use 
crate::config::TranscriptionConfig; + use crate::test_support::unique_temp_dir; use std::sync::{Mutex, OnceLock}; fn test_dir_guard() -> std::sync::MutexGuard<'static, ()> { @@ -815,4 +842,48 @@ mod tests { let config = TranscriptionConfig::default(); assert!(prepare_service(&config).is_none()); } + + #[tokio::test] + async fn transcribe_times_out_when_request_write_stalls() { + let runtime_dir = unique_temp_dir("nemo-request-timeout"); + let socket_path = runtime_dir.join("asr.sock"); + let listener = tokio::net::UnixListener::bind(&socket_path).expect("bind stalled socket"); + + let server = tokio::spawn(async move { + let (probe, _) = listener.accept().await.expect("accept readiness probe"); + drop(probe); + let (_stalled_request, _) = listener.accept().await.expect("accept ASR request"); + tokio::time::sleep(Duration::from_secs(1)).await; + }); + + let service = NemoAsrService { + socket_path: socket_path.clone(), + lock_path: runtime_dir.join("asr.lock"), + model_ref: "test-model".into(), + family: NemoModelFamily::Parakeet, + language: "en".into(), + use_gpu: false, + idle_timeout_ms: 0, + }; + + let err = service + .transcribe_with_timeout( + &oversized_audio(512 * 1024), + 16_000, + Duration::from_millis(50), + ) + .await + .expect_err("stalled ASR request should time out"); + let message = match err { + WhsprError::Transcription(message) => message, + other => panic!("unexpected error: {other:?}"), + }; + assert!(message.contains("NeMo ASR worker timed out")); + + server.abort(); + } + + fn oversized_audio(samples: usize) -> Vec { + vec![0.0; samples] + } } diff --git a/src/personalization/mod.rs b/src/personalization/mod.rs index d6105e6..5f1f2aa 100644 --- a/src/personalization/mod.rs +++ b/src/personalization/mod.rs @@ -1,5 +1,6 @@ use crate::config::Config; use crate::error::Result; +use crate::structured_text; mod rewrite; mod store; @@ -59,6 +60,9 @@ pub fn load_rules(config: &Config) -> Result { pub fn finalize_text(text: &str, rules: 
&PersonalizationRules) -> String { let corrected = apply_dictionary(text, rules); let expanded = expand_snippets(&corrected, rules); + if let Some(normalized) = structured_text::normalize_strict_structured_text(&expanded) { + return normalized; + } normalize_numeric_dot_runs(&expanded) } @@ -448,6 +452,12 @@ mod tests { assert_eq!(finalized, "MPL 2.0 and TLS 1.3 are common references"); } + #[test] + fn finalize_text_collapses_spaced_structured_punctuation_runs() { + let finalized = finalize_text("portfolio . notes . supply", &rules()); + assert_eq!(finalized, "portfolio.notes.supply"); + } + #[test] fn finalize_text_preserves_sentence_period_before_words() { let finalized = finalize_text("Section 2. Next step", &rules()); diff --git a/src/personalization/rewrite.rs b/src/personalization/rewrite.rs index 4c4b455..5f80041 100644 --- a/src/personalization/rewrite.rs +++ b/src/personalization/rewrite.rs @@ -6,6 +6,7 @@ use crate::rewrite_protocol::{ RewriteIntentConfidence, RewritePolicyContext, RewriteReplacementScope, RewriteTailShape, RewriteTranscript, RewriteTranscriptSegment, }; +use crate::structured_text; use crate::transcribe::Transcript; use super::{PersonalizationRules, WordSpan, apply_dictionary, collect_word_spans}; @@ -170,6 +171,12 @@ fn build_rewrite_candidates( raw_text, rules, ); + push_structured_literal_candidate( + &mut candidates, + raw_text, + correction_aware_text, + aggressive_correction_text, + ); push_rewrite_candidate( &mut candidates, RewriteCandidateKind::ConservativeCorrection, @@ -328,6 +335,34 @@ fn build_rewrite_candidates( candidates } +fn push_structured_literal_candidate( + candidates: &mut Vec, + raw_text: &str, + correction_aware_text: &str, + aggressive_correction_text: Option<&str>, +) { + let structured = structured_text::extract_structured_candidate(raw_text) + .or_else(|| structured_text::extract_structured_candidate(correction_aware_text)) + .or_else(|| { + 
aggressive_correction_text.and_then(structured_text::extract_structured_candidate) + }); + let Some(structured) = structured else { + return; + }; + + if candidates.iter().any(|candidate| { + candidate.kind == RewriteCandidateKind::StructuredLiteral + && candidate.text == structured.normalized + }) { + return; + } + + candidates.push(RewriteCandidate { + kind: RewriteCandidateKind::StructuredLiteral, + text: structured.normalized, + }); +} + fn push_rewrite_candidate( candidates: &mut Vec, kind: RewriteCandidateKind, @@ -684,15 +719,16 @@ fn candidate_priority(kind: RewriteCandidateKind) -> u8 { RewriteCandidateKind::SpanReplacement => 0, RewriteCandidateKind::ClauseReplacement => 1, RewriteCandidateKind::SentenceReplacement => 2, - RewriteCandidateKind::ContextualReplacement => 3, - RewriteCandidateKind::AggressiveCorrection => 4, - RewriteCandidateKind::CancelPreviousSentence => 5, - RewriteCandidateKind::CancelPreviousClause => 6, - RewriteCandidateKind::FollowingReplacement => 7, - RewriteCandidateKind::GlossaryCorrection => 8, - RewriteCandidateKind::ConservativeCorrection => 9, - RewriteCandidateKind::DropCueOnly => 10, - RewriteCandidateKind::Literal => 11, + RewriteCandidateKind::StructuredLiteral => 3, + RewriteCandidateKind::ContextualReplacement => 4, + RewriteCandidateKind::AggressiveCorrection => 5, + RewriteCandidateKind::CancelPreviousSentence => 6, + RewriteCandidateKind::CancelPreviousClause => 7, + RewriteCandidateKind::FollowingReplacement => 8, + RewriteCandidateKind::GlossaryCorrection => 9, + RewriteCandidateKind::ConservativeCorrection => 10, + RewriteCandidateKind::DropCueOnly => 11, + RewriteCandidateKind::Literal => 12, } } @@ -875,6 +911,61 @@ mod tests { assert!(rewrite.recommended_candidate.is_some()); } + #[test] + fn build_rewrite_transcript_adds_structured_literal_candidate_for_domains() { + let transcript = Transcript { + raw_text: "portfolio. Notes. 
Supply is the URL".into(), + detected_language: Some("en".into()), + segments: Vec::new(), + }; + + let rewrite = build_rewrite_transcript(&transcript, &rules()); + assert!(rewrite.rewrite_candidates.iter().any(|candidate| { + candidate.kind == RewriteCandidateKind::StructuredLiteral + && candidate.text == "portfolio.notes.supply" + })); + } + + #[test] + fn build_rewrite_transcript_keeps_structured_literal_kind_for_exact_literal_text() { + let transcript = Transcript { + raw_text: "https://Example.com/Repo?x=1#Frag".into(), + detected_language: Some("en".into()), + segments: Vec::new(), + }; + + let rewrite = build_rewrite_transcript(&transcript, &rules()); + let kinds = rewrite + .rewrite_candidates + .iter() + .filter(|candidate| candidate.text == "https://Example.com/Repo?x=1#Frag") + .map(|candidate| candidate.kind) + .collect::>(); + + assert!(kinds.contains(&RewriteCandidateKind::Literal)); + assert!(kinds.contains(&RewriteCandidateKind::StructuredLiteral)); + } + + #[test] + fn build_rewrite_transcript_preserves_prefixed_structured_literal_text() { + let transcript = Transcript { + raw_text: "/api/v1".into(), + detected_language: Some("en".into()), + segments: Vec::new(), + }; + + let rewrite = build_rewrite_transcript(&transcript, &rules()); + let kinds = rewrite + .rewrite_candidates + .iter() + .filter(|candidate| candidate.text == "/api/v1") + .map(|candidate| candidate.kind) + .collect::>(); + + assert!(kinds.contains(&RewriteCandidateKind::Literal)); + assert!(kinds.contains(&RewriteCandidateKind::StructuredLiteral)); + } + #[test] fn contextual_replacement_can_preserve_unlike_prefix() { let transcript = Transcript { diff --git a/src/personalization/store.rs b/src/personalization/store.rs index ffe0af9..018652f 100644 --- a/src/personalization/store.rs +++ b/src/personalization/store.rs @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize}; use crate::config::Config; use crate::error::{Result, WhsprError}; +use crate::safe_fs; use 
super::normalized_words; @@ -127,7 +128,7 @@ pub(super) fn load_custom_instructions(config: &Config) -> Result { return Ok(String::new()); }; - match std::fs::read_to_string(&path) { + match safe_fs::read_to_string(&path) { Ok(contents) => Ok(contents.trim().to_string()), Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(String::new()), Err(err) => Err(WhsprError::Config(format!( @@ -142,7 +143,7 @@ pub(super) fn read_dictionary_file(path: &Path) -> Result> return Ok(Vec::new()); } - let contents = std::fs::read_to_string(path).map_err(|e| { + let contents = safe_fs::read_to_string(path).map_err(|e| { WhsprError::Config(format!("failed to read dictionary {}: {e}", path.display())) })?; let file: DictionaryFile = toml::from_str(&contents).map_err(|e| { @@ -161,7 +162,7 @@ pub(super) fn write_dictionary_file(path: &Path, entries: &[DictionaryEntry]) -> }; let contents = toml::to_string_pretty(&file) .map_err(|e| WhsprError::Config(format!("failed to encode dictionary: {e}")))?; - std::fs::write(path, contents).map_err(|e| { + safe_fs::write(path, contents).map_err(|e| { WhsprError::Config(format!( "failed to write dictionary {}: {e}", path.display() @@ -175,7 +176,7 @@ pub(super) fn read_snippet_file(path: &Path) -> Result> { return Ok(Vec::new()); } - let contents = std::fs::read_to_string(path).map_err(|e| { + let contents = safe_fs::read_to_string(path).map_err(|e| { WhsprError::Config(format!("failed to read snippets {}: {e}", path.display())) })?; let file: SnippetFile = toml::from_str(&contents).map_err(|e| { @@ -191,7 +192,7 @@ pub(super) fn write_snippet_file(path: &Path, snippets: &[SnippetEntry]) -> Resu }; let contents = toml::to_string_pretty(&file) .map_err(|e| WhsprError::Config(format!("failed to encode snippets: {e}")))?; - std::fs::write(path, contents).map_err(|e| { + safe_fs::write(path, contents).map_err(|e| { WhsprError::Config(format!("failed to write snippets {}: {e}", path.display())) })?; Ok(()) diff --git 
a/src/postprocess/finalize.rs b/src/postprocess/finalize.rs index a1fc431..2473a33 100644 --- a/src/postprocess/finalize.rs +++ b/src/postprocess/finalize.rs @@ -1,18 +1,45 @@ use std::time::Duration; use std::time::Instant; +use std::{ + io, + panic::{AssertUnwindSafe, catch_unwind}, + thread, +}; use crate::agentic_rewrite; use crate::config::{Config, PostprocessMode, RewriteBackend, RewriteFallback}; use crate::context::TypingContext; +use crate::error::WhsprError; use crate::personalization::{self, PersonalizationRules}; -use crate::rewrite_protocol::{RewriteCorrectionPolicy, RewriteTranscript}; +use crate::rewrite_protocol::{RewriteCandidateKind, RewriteCorrectionPolicy, RewriteTranscript}; use crate::rewrite_worker::RewriteService; use crate::session::{EligibleSessionEntry, SessionRewriteSummary}; +use crate::structured_text; use crate::transcribe::Transcript; use super::{execution, planning}; const FEEDBACK_DRAIN_DELAY: Duration = Duration::from_millis(150); +const REWRITE_PLAN_TIMEOUT: Duration = Duration::from_secs(5); + +#[derive(Debug)] +enum TimedPlanningError { + Spawn(io::Error), + Timeout(Duration), + Panic, + ChannelClosed, +} + +impl TimedPlanningError { + fn describe(&self, phase: &str) -> String { + match self { + Self::Spawn(err) => format!("failed to start {phase} worker: {err}"), + Self::Timeout(timeout) => format!("{phase} timed out after {}ms", timeout.as_millis()), + Self::Panic => format!("{phase} worker panicked"), + Self::ChannelClosed => format!("{phase} worker exited without a result"), + } + } +} #[derive(Debug, Clone, PartialEq, Eq)] pub enum FinalizedOperation { @@ -28,19 +55,21 @@ pub struct FinalizedTranscript { pub text: String, pub operation: FinalizedOperation, pub rewrite_summary: SessionRewriteSummary, + pub degraded_reason: Option, } pub async fn finalize_transcript( config: &Config, transcript: Transcript, rewrite_service: Option<&RewriteService>, + runtime_resources: Option<&planning::RuntimeTextResources>, typing_context: 
Option<&TypingContext>, recent_session: Option<&EligibleSessionEntry>, ) -> FinalizedTranscript { let started = Instant::now(); let finalized = match config.postprocess.mode { PostprocessMode::Raw => { - let rules = planning::load_runtime_rules(config); + let resources = load_runtime_text_resources_or_default(config, runtime_resources).await; finalize_plain_text( planning::raw_text(&transcript), SessionRewriteSummary { @@ -48,11 +77,11 @@ pub async fn finalize_transcript( rewrite_used: false, recommended_candidate: None, }, - &rules, + &resources.rules, ) } PostprocessMode::LegacyBasic => { - let rules = planning::load_runtime_rules(config); + let resources = load_runtime_text_resources_or_default(config, runtime_resources).await; finalize_plain_text( crate::cleanup::clean_transcript(&transcript, &config.cleanup), SessionRewriteSummary { @@ -60,16 +89,22 @@ pub async fn finalize_transcript( rewrite_used: false, recommended_candidate: None, }, - &rules, + &resources.rules, ) } PostprocessMode::Rewrite => { - finalize_rewrite_plan_or_fallback( + match build_rewrite_plan_with_timeout( config, - rewrite_service, - planning::build_rewrite_plan(config, &transcript, typing_context, recent_session), + runtime_resources, + &transcript, + typing_context, + recent_session, ) .await + { + Ok(plan) => finalize_rewrite_plan_or_fallback(config, rewrite_service, plan).await, + Err(err) => finalize_rewrite_planning_failure(&transcript, &err), + } } }; tracing::info!( @@ -86,6 +121,95 @@ pub async fn wait_for_feedback_drain() { tokio::time::sleep(FEEDBACK_DRAIN_DELAY).await; } +async fn load_runtime_text_resources_or_default( + config: &Config, + runtime_resources: Option<&planning::RuntimeTextResources>, +) -> planning::RuntimeTextResources { + if let Some(resources) = runtime_resources { + return resources.clone(); + } + + let (resources, degraded) = planning::load_runtime_text_resources_with_status(config); + if degraded { + tracing::warn!("failed to load runtime text resources; 
using available defaults"); + } + resources +} + +async fn build_rewrite_plan_with_timeout( + config: &Config, + runtime_resources: Option<&planning::RuntimeTextResources>, + transcript: &Transcript, + typing_context: Option<&TypingContext>, + recent_session: Option<&EligibleSessionEntry>, +) -> crate::error::Result { + let config = config.clone(); + let runtime_resources = runtime_resources.cloned(); + let transcript = transcript.clone(); + let typing_context = typing_context.cloned(); + let recent_session = recent_session.cloned(); + + run_rewrite_task_with_timeout(REWRITE_PLAN_TIMEOUT, "rewrite planning", move || { + let runtime_resources = + runtime_resources.unwrap_or_else(|| planning::load_runtime_text_resources(&config)); + planning::build_rewrite_plan( + &config, + &runtime_resources, + &transcript, + typing_context.as_ref(), + recent_session.as_ref(), + ) + }) + .await + .map_err(|err| WhsprError::Rewrite(err.describe("rewrite planning"))) +} + +async fn run_rewrite_task_with_timeout( + timeout: Duration, + phase: &'static str, + task: F, +) -> Result +where + T: Send + 'static, + F: FnOnce() -> T + Send + 'static, +{ + let (tx, rx) = tokio::sync::oneshot::channel(); + // Rewrite planning runs at most once per dictation process, so a dedicated + // thread avoids head-of-line blocking from unrelated timeout-guarded work. 
+ thread::Builder::new() + .name(format!("whispers-{}", phase.replace(' ', "-"))) + .spawn(move || { + let result = + catch_unwind(AssertUnwindSafe(task)).map_err(|_| TimedPlanningError::Panic); + let _ = tx.send(result); + }) + .map_err(TimedPlanningError::Spawn)?; + + match tokio::time::timeout(timeout, rx).await { + Ok(Ok(Ok(value))) => Ok(value), + Ok(Ok(Err(err))) => Err(err), + Ok(Err(_)) => Err(TimedPlanningError::ChannelClosed), + Err(_) => Err(TimedPlanningError::Timeout(timeout)), + } +} + +fn finalize_rewrite_planning_failure( + transcript: &Transcript, + err: &crate::error::WhsprError, +) -> FinalizedTranscript { + tracing::warn!("rewrite planning failed: {err}; using raw transcript fallback"); + FinalizedTranscript { + text: planning::raw_text(transcript), + operation: FinalizedOperation::Append, + rewrite_summary: SessionRewriteSummary { + had_edit_cues: false, + rewrite_used: false, + recommended_candidate: None, + }, + degraded_reason: Some("rewrite_planning_failed".into()), + } +} + async fn finalize_rewrite_plan_or_fallback( config: &Config, rewrite_service: Option<&RewriteService>, @@ -119,18 +243,24 @@ fn finalize_rewrite_attempt( plan: planning::RewritePlan, rewrite_result: crate::error::Result, ) -> FinalizedTranscript { - let (base, rewrite_used) = match rewrite_result { + let (base, rewrite_used, degraded_reason) = match rewrite_result { Ok(text) if rewrite_output_accepted(config, &plan.rewrite_transcript, &text) => { + let text = + canonicalize_structured_output(&plan.rewrite_transcript, &text).unwrap_or(text); tracing::debug!( output_len = text.len(), mode = config.postprocess.mode.as_str(), "rewrite applied successfully" ); - (text, true) + (text, true, None) } Ok(text) if text.trim().is_empty() => { tracing::warn!("rewrite model returned empty text; using fallback"); - (plan.fallback_text, false) + ( + plan.fallback_text, + false, + Some("rewrite_empty_output".into()), + ) } Ok(text) => { tracing::warn!( @@ -138,11 +268,19 @@ fn 
finalize_rewrite_attempt( output_len = text.len(), "rewrite output failed acceptance guard; using fallback" ); - (plan.fallback_text, false) + ( + plan.fallback_text, + false, + Some("rewrite_output_rejected".into()), + ) } Err(err) => { tracing::warn!("rewrite failed: {err}; using fallback"); - (plan.fallback_text, false) + ( + plan.fallback_text, + false, + Some("rewrite_execution_failed".into()), + ) } }; @@ -155,6 +293,7 @@ fn finalize_rewrite_attempt( }, &plan.rules, ) + .with_degraded_reason(degraded_reason) .with_operation(plan.operation) } @@ -167,6 +306,7 @@ fn finalize_plain_text( text: personalization::finalize_text(&text, rules), operation: FinalizedOperation::Append, rewrite_summary, + degraded_reason: None, } } @@ -180,6 +320,7 @@ fn finalize_unavailable_rewrite_fallback(plan: planning::RewritePlan) -> Finaliz }, &plan.rules, ) + .with_degraded_reason(Some("rewrite_unavailable".into())) .with_operation(plan.operation) } @@ -188,6 +329,11 @@ impl FinalizedTranscript { self.operation = operation; self } + + fn with_degraded_reason(mut self, degraded_reason: Option) -> Self { + self.degraded_reason = degraded_reason; + self + } } fn rewrite_output_accepted( @@ -199,6 +345,15 @@ fn rewrite_output_accepted( return false; } + if let Some(candidate) = strict_structured_literal_candidate(rewrite_transcript) { + return structured_text::output_matches_candidate(text, candidate); + } + if structured_literal_candidate(rewrite_transcript) + .is_some_and(|candidate| structured_text::output_matches_candidate(text, candidate)) + { + return false; + } + match rewrite_transcript.policy_context.correction_policy { RewriteCorrectionPolicy::Conservative => { agentic_rewrite::conservative_output_allowed(rewrite_transcript, text) @@ -207,15 +362,51 @@ fn rewrite_output_accepted( } } +fn structured_literal_candidate(rewrite_transcript: &RewriteTranscript) -> Option<&str> { + rewrite_transcript + .rewrite_candidates + .iter() + .find(|candidate| candidate.kind == 
RewriteCandidateKind::StructuredLiteral) + .map(|candidate| candidate.text.as_str()) +} + +fn strict_structured_literal_candidate(rewrite_transcript: &RewriteTranscript) -> Option<&str> { + let candidate = structured_literal_candidate(rewrite_transcript)?; + structured_literal_source_matches_candidate(rewrite_transcript, candidate).then_some(candidate) +} + +fn structured_literal_source_matches_candidate( + rewrite_transcript: &RewriteTranscript, + candidate: &str, +) -> bool { + [ + Some(rewrite_transcript.raw_text.as_str()), + Some(rewrite_transcript.correction_aware_text.as_str()), + rewrite_transcript.aggressive_correction_text.as_deref(), + ] + .into_iter() + .flatten() + .any(|text| structured_text::output_matches_candidate(text, candidate)) +} + +fn canonicalize_structured_output( + rewrite_transcript: &RewriteTranscript, + text: &str, +) -> Option { + let candidate = strict_structured_literal_candidate(rewrite_transcript)?; + structured_text::output_matches_candidate(text, candidate).then(|| candidate.to_string()) +} + #[cfg(test)] mod tests { use std::path::PathBuf; + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; use super::*; use crate::rewrite_protocol::{ RewriteCandidate, RewriteCandidateKind, RewritePolicyContext, RewriteTranscript, }; - fn plan_config(mode: PostprocessMode, backend: RewriteBackend) -> Config { let mut config = Config::default(); config.postprocess.mode = mode; @@ -305,4 +496,201 @@ mod tests { assert_eq!(finalized.text, "fallback text"); assert!(!finalized.rewrite_summary.rewrite_used); } + + #[tokio::test] + async fn rewrite_thread_timeout_returns_value_when_task_finishes() { + let result = run_rewrite_task_with_timeout(Duration::from_millis(50), "test task", || 7) + .await + .expect("task should finish"); + + assert_eq!(result, 7); + } + + #[tokio::test] + async fn rewrite_thread_timeout_returns_error_when_task_hangs() { + let completed = Arc::new(AtomicBool::new(false)); + let completed_in_thread = 
Arc::clone(&completed); + + let err = + run_rewrite_task_with_timeout(Duration::from_millis(10), "test task", move || { + std::thread::sleep(Duration::from_millis(50)); + completed_in_thread.store(true, Ordering::Relaxed); + }) + .await + .expect_err("task should time out"); + + assert_eq!(err.describe("test task"), "test task timed out after 10ms"); + + assert!(!completed.load(Ordering::Relaxed)); + } + + #[test] + fn rewrite_planning_failure_falls_back_to_trimmed_raw_text() { + let transcript = Transcript { + raw_text: " raw transcript ".into(), + detected_language: Some("en".into()), + segments: Vec::new(), + }; + + let finalized = finalize_rewrite_planning_failure( + &transcript, + &WhsprError::Rewrite("rewrite planning timed out after 5000ms".into()), + ); + + assert_eq!(finalized.text, "raw transcript"); + assert!(!finalized.rewrite_summary.rewrite_used); + assert_eq!(finalized.operation, FinalizedOperation::Append); + assert_eq!( + finalized.degraded_reason.as_deref(), + Some("rewrite_planning_failed") + ); + } + + #[test] + fn structured_literal_meta_wrapper_is_canonicalized() { + let config = plan_config(PostprocessMode::Rewrite, RewriteBackend::Cloud); + let mut plan = rewrite_plan(); + plan.rewrite_transcript.raw_text = "portfolio. Notes. Supply".into(); + plan.rewrite_transcript.correction_aware_text = "portfolio. Notes. Supply".into(); + plan.fallback_text = "portfolio.notes.supply".into(); + plan.rewrite_transcript.rewrite_candidates = vec![RewriteCandidate { + kind: RewriteCandidateKind::StructuredLiteral, + text: "portfolio.notes.supply".into(), + }]; + + let finalized = finalize_rewrite_attempt( + &config, + plan, + Ok("portfolio. Notes. 
Supply is the URL".into()), + ); + + assert_eq!(finalized.text, "portfolio.notes.supply"); + assert!(finalized.rewrite_summary.rewrite_used); + } + + #[test] + fn structured_literal_url_with_query_and_fragment_meta_wrapper_is_canonicalized() { + let config = plan_config(PostprocessMode::Rewrite, RewriteBackend::Cloud); + let mut plan = rewrite_plan(); + plan.rewrite_transcript.raw_text = "https://example.com/?q=test&lang=en#frag".into(); + plan.rewrite_transcript.correction_aware_text = + "https://example.com/?q=test&lang=en#frag".into(); + plan.fallback_text = "https://example.com/?q=test&lang=en#frag".into(); + plan.rewrite_transcript.rewrite_candidates = vec![RewriteCandidate { + kind: RewriteCandidateKind::StructuredLiteral, + text: "https://example.com/?q=test&lang=en#frag".into(), + }]; + + let finalized = finalize_rewrite_attempt( + &config, + plan, + Ok("https://example.com/?q=test&lang=en#frag is the URL".into()), + ); + + assert_eq!(finalized.text, "https://example.com/?q=test&lang=en#frag"); + assert!(finalized.rewrite_summary.rewrite_used); + } + + #[test] + fn exact_structured_literal_still_canonicalizes_meta_wrapped_rewrite_output() { + let config = plan_config(PostprocessMode::Rewrite, RewriteBackend::Cloud); + let transcript = Transcript { + raw_text: "https://Example.com/Repo?x=1#Frag".into(), + detected_language: Some("en".into()), + segments: Vec::new(), + }; + let plan = planning::build_rewrite_plan( + &config, + &planning::load_runtime_text_resources(&config), + &transcript, + None, + None, + ); + + let finalized = finalize_rewrite_attempt( + &config, + plan, + Ok("the URL is https://Example.com/Repo?x=1#Frag".into()), + ); + + assert_eq!(finalized.text, "https://Example.com/Repo?x=1#Frag"); + assert!(finalized.rewrite_summary.rewrite_used); + } + + #[test] + fn prefixed_structured_literal_still_canonicalizes_meta_wrapped_rewrite_output() { + let config = plan_config(PostprocessMode::Rewrite, RewriteBackend::Cloud); + let transcript = 
Transcript { + raw_text: "/api/v1".into(), + detected_language: Some("en".into()), + segments: Vec::new(), + }; + let plan = planning::build_rewrite_plan( + &config, + &planning::load_runtime_text_resources(&config), + &transcript, + None, + None, + ); + + let finalized = finalize_rewrite_attempt(&config, plan, Ok("URL: /api/v1".into())); + + assert_eq!(finalized.text, "/api/v1"); + assert!(finalized.rewrite_summary.rewrite_used); + } + + #[test] + fn structured_literal_output_is_canonicalized_when_accepted() { + let config = plan_config(PostprocessMode::Rewrite, RewriteBackend::Cloud); + let mut plan = rewrite_plan(); + plan.rewrite_transcript.raw_text = "portfolio. Notes. Supply".into(); + plan.rewrite_transcript.correction_aware_text = "portfolio. Notes. Supply".into(); + plan.rewrite_transcript.rewrite_candidates = vec![RewriteCandidate { + kind: RewriteCandidateKind::StructuredLiteral, + text: "portfolio.notes.supply".into(), + }]; + + let finalized = + finalize_rewrite_attempt(&config, plan, Ok("portfolio . notes . supply".into())); + + assert_eq!(finalized.text, "portfolio.notes.supply"); + assert!(finalized.rewrite_summary.rewrite_used); + } + + #[test] + fn structured_literal_embedded_sentence_rejects_lossy_candidate_only_output() { + let config = plan_config(PostprocessMode::Rewrite, RewriteBackend::Cloud); + let mut plan = rewrite_plan(); + plan.fallback_text = "Check portfolio. Notes. Supply tomorrow".into(); + plan.rewrite_transcript.rewrite_candidates = vec![RewriteCandidate { + kind: RewriteCandidateKind::StructuredLiteral, + text: "portfolio.notes.supply".into(), + }]; + + let finalized = + finalize_rewrite_attempt(&config, plan, Ok("portfolio.notes.supply".into())); + + assert_eq!(finalized.text, "Check portfolio. Notes. 
Supply tomorrow"); + assert!(!finalized.rewrite_summary.rewrite_used); + } + + #[test] + fn structured_literal_embedded_sentence_accepts_full_sentence_rewrite() { + let config = plan_config(PostprocessMode::Rewrite, RewriteBackend::Cloud); + let mut plan = rewrite_plan(); + plan.fallback_text = "Check portfolio. Notes. Supply tomorrow".into(); + plan.rewrite_transcript.rewrite_candidates = vec![RewriteCandidate { + kind: RewriteCandidateKind::StructuredLiteral, + text: "portfolio.notes.supply".into(), + }]; + + let finalized = finalize_rewrite_attempt( + &config, + plan, + Ok("Check portfolio.notes.supply tomorrow.".into()), + ); + + assert_eq!(finalized.text, "Check portfolio.notes.supply tomorrow."); + assert!(finalized.rewrite_summary.rewrite_used); + } } diff --git a/src/postprocess/planning.rs b/src/postprocess/planning.rs index c1b23a8..30d0a76 100644 --- a/src/postprocess/planning.rs +++ b/src/postprocess/planning.rs @@ -6,12 +6,21 @@ use crate::config::{Config, PostprocessMode}; use crate::context::TypingContext; use crate::personalization::{self, PersonalizationRules}; use crate::rewrite_model; -use crate::rewrite_protocol::{RewriteSessionBacktrackCandidateKind, RewriteTranscript}; +use crate::rewrite_protocol::{ + RewriteCandidateKind, RewriteSessionBacktrackCandidateKind, RewriteTranscript, +}; use crate::session::{self, EligibleSessionEntry}; +use crate::structured_text; use crate::transcribe::Transcript; use super::finalize::FinalizedOperation; +#[derive(Debug, Clone, Default)] +pub struct RuntimeTextResources { + pub(crate) rules: PersonalizationRules, + pub(crate) runtime_policy: agentic_rewrite::RuntimePolicyResources, +} + pub(crate) struct RewritePlan { pub rules: PersonalizationRules, pub fallback_text: String, @@ -35,29 +44,70 @@ pub(crate) fn resolve_rewrite_model_path(config: &Config) -> Option { rewrite_model::selected_model_path(&config.rewrite.selected_model) } -pub(crate) fn load_runtime_rules(config: &Config) -> PersonalizationRules { 
+pub(crate) fn load_runtime_rules_with_status(config: &Config) -> (PersonalizationRules, bool) { match personalization::load_rules(config) { - Ok(rules) => rules, + Ok(rules) => (rules, false), Err(err) => { tracing::warn!("failed to load personalization rules: {err}"); - PersonalizationRules::default() + (PersonalizationRules::default(), true) } } } +pub fn load_runtime_text_resources(config: &Config) -> RuntimeTextResources { + load_runtime_text_resources_with_status(config).0 +} + +pub fn load_runtime_text_resources_with_status(config: &Config) -> (RuntimeTextResources, bool) { + let (rules, rules_degraded) = load_runtime_rules_with_status(config); + let (runtime_policy, policy_degraded) = + agentic_rewrite::load_runtime_resources_with_status(config); + + ( + RuntimeTextResources { + rules, + runtime_policy, + }, + rules_degraded || policy_degraded, + ) +} + pub(crate) fn build_rewrite_plan( config: &Config, + resources: &RuntimeTextResources, transcript: &Transcript, typing_context: Option<&TypingContext>, recent_session: Option<&EligibleSessionEntry>, ) -> RewritePlan { - let rules = load_runtime_rules(config); + let rules = resources.rules.clone(); let local_model_path = resolve_rewrite_model_path(config); let mut rewrite_transcript = personalization::build_rewrite_transcript(transcript, &rules); rewrite_transcript.typing_context = typing_context.and_then(session::to_rewrite_typing_context); - agentic_rewrite::apply_runtime_policy(config, &mut rewrite_transcript); + agentic_rewrite::apply_runtime_policy_with_resources( + config, + &mut rewrite_transcript, + &resources.runtime_policy, + ); let session_plan = session::build_backtrack_plan(&rewrite_transcript, recent_session); let mut fallback_text = base_text(config, transcript); + if let Some(candidate) = rewrite_transcript + .rewrite_candidates + .iter() + .find(|candidate| candidate.kind == RewriteCandidateKind::StructuredLiteral) + { + let candidate_text = candidate.text.as_str(); + let 
prefer_structured_fallback = [ + Some(transcript.raw_text.as_str()), + Some(rewrite_transcript.correction_aware_text.as_str()), + rewrite_transcript.aggressive_correction_text.as_deref(), + ] + .into_iter() + .flatten() + .any(|text| structured_text::output_matches_candidate(text, candidate_text)); + if prefer_structured_fallback { + fallback_text = candidate.text.clone(); + } + } if session_plan.recommended.as_ref().is_some_and(|candidate| { matches!( candidate.kind, @@ -145,7 +195,7 @@ fn recommended_operation(rewrite_transcript: &RewriteTranscript) -> FinalizedOpe #[cfg(test)] mod tests { - use super::build_rewrite_plan; + use super::{build_rewrite_plan, load_runtime_text_resources}; use crate::config::{Config, PostprocessMode}; use crate::context::SurfaceKind; use crate::postprocess::finalize::FinalizedOperation; @@ -181,7 +231,13 @@ mod tests { delete_graphemes: 11, }; - let plan = build_rewrite_plan(&config, &transcript, None, Some(&recent)); + let plan = build_rewrite_plan( + &config, + &load_runtime_text_resources(&config), + &transcript, + None, + Some(&recent), + ); assert_eq!(plan.fallback_text, "Hi"); assert_eq!(plan.recommended_candidate.as_deref(), Some("Hi")); assert_eq!( @@ -192,4 +248,49 @@ mod tests { } ); } + + #[test] + fn build_rewrite_plan_prefers_structured_literal_for_fallback() { + let mut config = Config::default(); + config.postprocess.mode = PostprocessMode::Rewrite; + + let transcript = Transcript { + raw_text: "portfolio. Notes. 
Supply is the URL".into(), + detected_language: Some("en".into()), + segments: Vec::new(), + }; + + let plan = build_rewrite_plan( + &config, + &load_runtime_text_resources(&config), + &transcript, + None, + None, + ); + assert_eq!(plan.fallback_text, "portfolio.notes.supply"); + } + + #[test] + fn build_rewrite_plan_keeps_full_fallback_when_structured_text_is_embedded() { + let mut config = Config::default(); + config.postprocess.mode = PostprocessMode::Rewrite; + + let transcript = Transcript { + raw_text: "Check portfolio. Notes. Supply tomorrow".into(), + detected_language: Some("en".into()), + segments: Vec::new(), + }; + + let plan = build_rewrite_plan( + &config, + &load_runtime_text_resources(&config), + &transcript, + None, + None, + ); + assert_eq!( + plan.fallback_text, + "Check portfolio. Notes. Supply tomorrow" + ); + } } diff --git a/src/rewrite/prompt.rs b/src/rewrite/prompt.rs index 5da5c69..a328ee4 100644 --- a/src/rewrite/prompt.rs +++ b/src/rewrite/prompt.rs @@ -145,7 +145,9 @@ word. Use nearby category words like window manager, editor, language, library, tool to disambiguate technical names. When a dictated word is an obvious phonetic near-miss for a likely technical term \ and the surrounding context clearly identifies the category, correct it to the canonical technical spelling instead of \ echoing the miss. If multiple plausible interpretations remain similarly credible, stay close to the transcript rather \ -than inventing a niche term. \ +than inventing a niche term. When the utterance is a hostname, URL, email address, or other structured text, preserve \ +dots, slashes, colons, dashes, underscores, and at-signs literally. Do not turn structured labels into sentence \ +punctuation and do not append explanatory prose such as saying something is a URL. \ If an edit intent says to replace or cancel previous wording, preserve that edit when the utterance or same-session \ context clearly supports it. 
Preserve utterance-initial courtesy or apology wording when the raw transcript still \ clearly intends it. Examples:\n\ @@ -179,7 +181,12 @@ phonetically similar common word. Use nearby category words like window manager, manager, shell, or terminal tool to disambiguate technical names. When a dictated word is an obvious phonetic near-miss \ for a likely technical term and the surrounding context clearly identifies the category, correct it to the canonical \ technical spelling instead of echoing the miss. If multiple plausible interpretations remain similarly credible, stay \ -close to the transcript rather than inventing a niche term. If an edit intent says to replace or cancel previous wording, preserve that edit when the utterance or same-session context clearly supports it. Preserve utterance-initial courtesy or apology wording when the raw transcript still clearly intends it. Examples:\n\ +close to the transcript rather than inventing a niche term. When the utterance is a hostname, URL, email address, or \ +other structured text, preserve dots, slashes, colons, dashes, underscores, and at-signs literally. Do not turn \ +structured labels into sentence punctuation and do not append explanatory prose such as saying something is a URL. If \ +an edit intent says to replace or cancel previous wording, preserve that edit when the utterance or same-session \ +context clearly supports it. Preserve utterance-initial courtesy or apology wording when the raw transcript still \ +clearly intends it. Examples:\n\ - raw: Hello there. Scratch that. 
Hi.\n correction-aware: Hi.\n final: Hi.\n\ - raw: I'll bring cookies, scratch that, brownies.\n correction-aware: I'll bring brownies.\n final: I'll bring brownies.\n\ - raw: My name is Notes, scratch that my name is Jonatan.\n correction-aware: My my name is Jonatan.\n aggressive correction-aware: My name is Jonatan.\n final: My name is Jonatan.\n\ @@ -324,7 +331,8 @@ Structured cue context:\n\ Self-corrections were already resolved before rewriting.\n\ Use this correction-aware transcript as the main source text. In agentic mode, you may still normalize likely \ technical terms or proper names when the utterance strongly supports them, even if the exact canonical spelling is not \ -already present in the candidate list:\n\ +already present in the candidate list. When a structured-text candidate is present, preserve its punctuation literally \ +and do not rewrite it into prose:\n\ {correction_aware}\n\ {agentic_candidates}\ Do not restore any canceled wording from earlier in the utterance.\n\ @@ -346,7 +354,8 @@ Correction-aware transcript:\n\ {correction_aware}\n\ Treat the correction-aware transcript as authoritative for explicit spoken edits and overall meaning, but in agentic \ mode you may normalize likely technical terms or proper names when category cues in the utterance make the intended \ -technical meaning clearly better than the literal transcript.\n\ +technical meaning clearly better than the literal transcript. 
When a structured-text candidate is present, preserve \ +its punctuation literally and do not rewrite it into prose.\n\ {agentic_candidates}\ \ Recent segments:\n\ @@ -559,6 +568,9 @@ fn render_rewrite_candidates(transcript: &RewriteTranscript) -> String { crate::rewrite_protocol::RewriteCandidateKind::Literal => { "literal (keep only if the cue was not actually an edit)" } + crate::rewrite_protocol::RewriteCandidateKind::StructuredLiteral => { + "structured_literal (preserve structured punctuation literally)" + } crate::rewrite_protocol::RewriteCandidateKind::ConservativeCorrection => { "conservative_correction (balanced cleanup)" } diff --git a/src/rewrite/tests.rs b/src/rewrite/tests.rs index cdb7930..b9d8206 100644 --- a/src/rewrite/tests.rs +++ b/src/rewrite/tests.rs @@ -181,6 +181,14 @@ fn base_instructions_allow_technical_term_inference() { assert!(instructions.contains("phonetically similar common word")); } +#[test] +fn instructions_cover_structured_text_literals() { + let instructions = rewrite_instructions(ResolvedRewriteProfile::LlamaCompat); + assert!(instructions.contains("hostname, URL, email address")); + assert!(instructions.contains("preserve dots, slashes, colons")); + assert!(instructions.contains("do not append explanatory prose")); +} + #[test] fn custom_instructions_append_to_system_prompt() { let instructions = build_system_instructions( @@ -240,6 +248,22 @@ fn fast_route_prompt_allows_agentic_technical_normalization() { ); } +#[test] +fn fast_route_prompt_mentions_preserving_structured_candidates() { + let mut transcript = fast_agentic_transcript(); + transcript.rewrite_candidates.insert( + 0, + RewriteCandidate { + kind: RewriteCandidateKind::StructuredLiteral, + text: "portfolio.notes.supply".into(), + }, + ); + + let prompt = build_user_message(&transcript); + assert!(prompt.contains("structured-text candidate is present")); + assert!(prompt.contains("do not rewrite it into prose")); +} + #[test] fn 
cue_prompt_includes_raw_candidate_and_signals() { let prompt = build_user_message(&correction_transcript()); diff --git a/src/rewrite_protocol.rs b/src/rewrite_protocol.rs index 33b13ff..cbd05b8 100644 --- a/src/rewrite_protocol.rs +++ b/src/rewrite_protocol.rs @@ -192,6 +192,7 @@ pub enum RewriteCorrectionPolicy { #[serde(rename_all = "snake_case")] pub enum RewriteCandidateKind { Literal, + StructuredLiteral, ConservativeCorrection, AggressiveCorrection, GlossaryCorrection, diff --git a/src/rewrite_worker.rs b/src/rewrite_worker.rs index a8fc518..2bf65f0 100644 --- a/src/rewrite_worker.rs +++ b/src/rewrite_worker.rs @@ -80,31 +80,36 @@ impl RewriteService { self.spawn_worker() } - async fn ensure_running(&self, timeout: Duration) -> Result<()> { + async fn ensure_running( + &self, + deadline: tokio::time::Instant, + timeout: Duration, + ) -> Result<()> { if self.is_running() { return Ok(()); } self.prewarm()?; - let deadline = tokio::time::Instant::now() + timeout; loop { - match UnixStream::connect(&self.socket_path).await { - Ok(stream) => { + let remaining = remaining_rewrite_budget(deadline, timeout)?; + match tokio::time::timeout(remaining, UnixStream::connect(&self.socket_path)).await { + Ok(Ok(stream)) => { drop(stream); return Ok(()); } - Err(err) if tokio::time::Instant::now() < deadline => { + Ok(Err(err)) if tokio::time::Instant::now() < deadline => { let _ = err; - tokio::time::sleep(Duration::from_millis(50)).await; + tokio::time::sleep(Duration::from_millis(50).min(remaining)).await; } - Err(err) => { + Ok(Err(err)) => { return Err(WhsprError::Rewrite(format!( "rewrite worker at {} did not become ready: {err}", self.socket_path.display() ))); } - } + Err(_) => return Err(rewrite_timeout_error(timeout)), + }; } } @@ -143,62 +148,60 @@ impl RewriteService { } } -pub async fn rewrite_with_service( - service: &RewriteService, - config: &RewriteConfig, - transcript: &RewriteTranscript, - custom_instructions: Option<&str>, -) -> Result { - let timeout = 
Duration::from_millis(config.timeout_ms); - tracing::trace!( - candidates = transcript.rewrite_candidates.len(), - hypotheses = transcript.edit_hypotheses.len(), - has_recommended = transcript.recommended_candidate.is_some(), - "sending rewrite request to worker" - ); - service.ensure_running(timeout).await?; +fn remaining_rewrite_budget(deadline: tokio::time::Instant, timeout: Duration) -> Result { + deadline + .checked_duration_since(tokio::time::Instant::now()) + .ok_or_else(|| rewrite_timeout_error(timeout)) +} - let mut stream = tokio::time::timeout(timeout, UnixStream::connect(&service.socket_path)) - .await - .map_err(|_| { - WhsprError::Rewrite(format!( - "rewrite worker timed out after {}ms", - timeout.as_millis() - )) - })? - .map_err(|e| { - WhsprError::Rewrite(format!( - "failed to connect to rewrite worker at {}: {e}", - service.socket_path.display() - )) - })?; +fn rewrite_timeout_error(timeout: Duration) -> WhsprError { + WhsprError::Rewrite(format!( + "rewrite worker timed out after {}ms", + timeout.as_millis() + )) +} - let mut payload = serde_json::to_vec(&WorkerRequest::Rewrite { - transcript: transcript.clone(), - custom_instructions: custom_instructions.map(str::to_owned), - }) - .map_err(|e| WhsprError::Rewrite(format!("failed to encode rewrite request: {e}")))?; - payload.push(b'\n'); - stream - .write_all(&payload) - .await - .map_err(|e| WhsprError::Rewrite(format!("failed to send rewrite request: {e}")))?; - stream - .flush() +async fn rewrite_request( + service: &RewriteService, + payload: &[u8], + deadline: tokio::time::Instant, + timeout: Duration, +) -> Result { + let mut stream = tokio::time::timeout( + remaining_rewrite_budget(deadline, timeout)?, + UnixStream::connect(&service.socket_path), + ) + .await + .map_err(|_| rewrite_timeout_error(timeout))? 
+ .map_err(|e| { + WhsprError::Rewrite(format!( + "failed to connect to rewrite worker at {}: {e}", + service.socket_path.display() + )) + })?; + + tokio::time::timeout( + remaining_rewrite_budget(deadline, timeout)?, + stream.write_all(payload), + ) + .await + .map_err(|_| rewrite_timeout_error(timeout))? + .map_err(|e| WhsprError::Rewrite(format!("failed to send rewrite request: {e}")))?; + + tokio::time::timeout(remaining_rewrite_budget(deadline, timeout)?, stream.flush()) .await + .map_err(|_| rewrite_timeout_error(timeout))? .map_err(|e| WhsprError::Rewrite(format!("failed to flush rewrite request: {e}")))?; let mut reader = BufReader::new(stream); let mut line = String::new(); - tokio::time::timeout(timeout, reader.read_line(&mut line)) - .await - .map_err(|_| { - WhsprError::Rewrite(format!( - "rewrite worker timed out after {}ms", - timeout.as_millis() - )) - })? - .map_err(|e| WhsprError::Rewrite(format!("failed to read rewrite response: {e}")))?; + tokio::time::timeout( + remaining_rewrite_budget(deadline, timeout)?, + reader.read_line(&mut line), + ) + .await + .map_err(|_| rewrite_timeout_error(timeout))? 
+ .map_err(|e| WhsprError::Rewrite(format!("failed to read rewrite response: {e}")))?; if line.trim().is_empty() { return Err(WhsprError::Rewrite( @@ -220,6 +223,31 @@ pub async fn rewrite_with_service( } } +pub async fn rewrite_with_service( + service: &RewriteService, + config: &RewriteConfig, + transcript: &RewriteTranscript, + custom_instructions: Option<&str>, +) -> Result { + let timeout = Duration::from_millis(config.timeout_ms); + let deadline = tokio::time::Instant::now() + timeout; + tracing::trace!( + candidates = transcript.rewrite_candidates.len(), + hypotheses = transcript.edit_hypotheses.len(), + has_recommended = transcript.recommended_candidate.is_some(), + "sending rewrite request to worker" + ); + service.ensure_running(deadline, timeout).await?; + + let mut payload = serde_json::to_vec(&WorkerRequest::Rewrite { + transcript: transcript.clone(), + custom_instructions: custom_instructions.map(str::to_owned), + }) + .map_err(|e| WhsprError::Rewrite(format!("failed to encode rewrite request: {e}")))?; + payload.push(b'\n'); + rewrite_request(service, &payload, deadline, timeout).await +} + fn runtime_dir() -> PathBuf { std::env::var("XDG_RUNTIME_DIR") .map(PathBuf::from) @@ -300,7 +328,13 @@ impl Drop for StartupLock { #[cfg(test)] mod tests { use super::*; + use crate::error::WhsprError; use crate::rewrite_profile::RewriteProfile; + use crate::rewrite_protocol::{ + RewriteCorrectionPolicy, RewritePolicyContext, RewriteTranscript, + }; + use crate::test_support::unique_temp_dir; + use std::path::PathBuf; #[test] fn service_paths_change_when_profile_changes() { @@ -339,4 +373,77 @@ mod tests { let service = RewriteService::new(&config, Path::new("/models/custom.gguf")); assert_eq!(service.profile, ResolvedRewriteProfile::Qwen); } + + #[tokio::test] + async fn rewrite_with_service_times_out_when_request_write_stalls() { + let runtime_dir = unique_temp_dir("rewrite-request-timeout"); + let socket_path = runtime_dir.join("rewrite.sock"); + let listener 
= + tokio::net::UnixListener::bind(&socket_path).expect("bind stalled rewrite socket"); + + let server = tokio::spawn(async move { + let (probe, _) = listener.accept().await.expect("accept readiness probe"); + drop(probe); + let (_stalled_request, _) = listener.accept().await.expect("accept rewrite request"); + tokio::time::sleep(Duration::from_secs(1)).await; + }); + + let service = RewriteService { + socket_path: socket_path.clone(), + lock_path: runtime_dir.join("rewrite.lock"), + model_path: PathBuf::from("/tmp/test.gguf"), + profile: ResolvedRewriteProfile::Generic, + max_tokens: 256, + max_output_chars: 1200, + idle_timeout_ms: 0, + }; + let config = RewriteConfig { + timeout_ms: 50, + ..RewriteConfig::default() + }; + + let err = rewrite_with_service( + &service, + &config, + &oversized_transcript(2 * 1024 * 1024), + None, + ) + .await + .expect_err("stalled rewrite request should time out"); + let message = match err { + WhsprError::Rewrite(message) => message, + other => panic!("unexpected error: {other:?}"), + }; + assert!(message.contains("rewrite worker timed out")); + + server.abort(); + } + + fn oversized_transcript(size: usize) -> RewriteTranscript { + let text = "word ".repeat(size / 5); + RewriteTranscript { + raw_text: text.clone(), + correction_aware_text: text, + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: None, + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: Vec::new(), + edit_intents: Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + rewrite_candidates: Vec::new(), + recommended_candidate: None, + edit_context: Default::default(), + policy_context: RewritePolicyContext { + correction_policy: RewriteCorrectionPolicy::Balanced, + matched_rule_names: Vec::new(), + effective_rule_instructions: Vec::new(), + active_glossary_terms: Vec::new(), + glossary_candidates: Vec::new(), + }, + } + } } 
diff --git a/src/runtime_diagnostics.rs b/src/runtime_diagnostics.rs new file mode 100644 index 0000000..375012b --- /dev/null +++ b/src/runtime_diagnostics.rs @@ -0,0 +1,894 @@ +use std::fmt::Write as _; +use std::path::{Path, PathBuf}; +use std::process::Command; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; + +use crate::config::{Config, RewriteBackend, TranscriptionBackend}; +use crate::runtime_guards::run_command_output_with_timeout; +use crate::safe_fs; + +const STATUS_FILE_NAME: &str = "main-status.json"; +const HANG_DEBUG_ENV: &str = "WHISPERS_HANG_DEBUG"; +const DEFAULT_PREPARE_HANG_TIMEOUT: Duration = Duration::from_secs(20); +const DEFAULT_TRANSCRIBE_HANG_TIMEOUT: Duration = Duration::from_secs(90); +const DEFAULT_POSTPROCESS_HANG_TIMEOUT: Duration = Duration::from_secs(15); +const DEFAULT_INJECT_HANG_TIMEOUT: Duration = Duration::from_secs(8); +const DEFAULT_SESSION_WRITE_HANG_TIMEOUT: Duration = Duration::from_secs(3); +const DEFAULT_WATCHDOG_POLL_INTERVAL: Duration = Duration::from_millis(250); +const DEFAULT_COMMAND_CAPTURE_TIMEOUT: Duration = Duration::from_secs(1); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum DictationStage { + Starting, + Recording, + AsrPrepare, + RecordingStopped, + AsrTranscribe, + Postprocess, + Inject, + SessionWrite, + Done, + Cancelled, +} + +impl DictationStage { + pub(crate) fn as_str(self) -> &'static str { + match self { + Self::Starting => "starting", + Self::Recording => "recording", + Self::AsrPrepare => "asr_prepare", + Self::RecordingStopped => "recording_stopped", + Self::AsrTranscribe => "asr_transcribe", + Self::Postprocess => "postprocess", + Self::Inject => "inject", + Self::SessionWrite => "session_write", + Self::Done => "done", + Self::Cancelled => "cancelled", + } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] +pub(crate) struct 
DictationStageMetadata { + pub substage: Option, + pub detail: Option, + pub audio_samples: Option, + pub sample_rate: Option, + pub audio_duration_ms: Option, + pub transcript_chars: Option, + pub output_chars: Option, + pub operation: Option, + pub rewrite_used: Option, + pub degraded_reason: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub(crate) struct MainStatusSnapshot { + pub pid: u32, + pub started_at_ms: u64, + pub stage: String, + pub stage_started_at_ms: u64, + pub transcription_backend: String, + pub asr_model: String, + pub rewrite_backend: String, + pub rewrite_model: String, + pub metadata: DictationStageMetadata, +} + +#[derive(Debug)] +struct State { + pid: u32, + started_at_ms: u64, + stage: DictationStage, + stage_started_at_ms: u64, + transcription_backend: String, + asr_model: String, + rewrite_backend: String, + rewrite_model: String, + metadata: DictationStageMetadata, +} + +#[derive(Debug, Clone, Copy)] +struct HangWatchdogConfig { + enabled: bool, + prepare_timeout: Duration, + transcribe_timeout: Duration, + postprocess_timeout: Duration, + inject_timeout: Duration, + session_write_timeout: Duration, + poll_interval: Duration, + command_timeout: Duration, +} + +#[derive(Debug)] +struct WatchdogControl { + #[cfg(test)] + enabled: bool, + stop: AtomicBool, + handle: Mutex>>, +} + +#[derive(Clone, Debug)] +pub(crate) struct DictationRuntimeDiagnostics { + status_path: PathBuf, + state: Arc>, + ownership: Arc<()>, + watchdog: Arc, +} + +impl DictationRuntimeDiagnostics { + pub(crate) fn new(config: &Config) -> Self { + Self::build(config, hang_watchdog_config(config)) + } + + fn build(config: &Config, hang_watchdog: HangWatchdogConfig) -> Self { + let now = now_ms(); + let watchdog_enabled = hang_watchdog.enabled; + let diagnostics = Self { + status_path: status_file_path(), + state: Arc::new(Mutex::new(State { + pid: std::process::id(), + started_at_ms: now, + stage: DictationStage::Starting, + 
stage_started_at_ms: now, + transcription_backend: config.transcription.backend.as_str().to_string(), + asr_model: transcription_model_label(config), + rewrite_backend: config.rewrite.backend.as_str().to_string(), + rewrite_model: rewrite_model_label(config), + metadata: DictationStageMetadata::default(), + })), + ownership: Arc::new(()), + watchdog: Arc::new(WatchdogControl { + #[cfg(test)] + enabled: watchdog_enabled, + stop: AtomicBool::new(false), + handle: Mutex::new(None), + }), + }; + diagnostics.start_watchdog(HangWatchdogConfig { + enabled: watchdog_enabled, + ..hang_watchdog + }); + diagnostics.persist_snapshot(); + tracing::info!( + stage = DictationStage::Starting.as_str(), + "dictation stage entered" + ); + diagnostics + } + + pub(crate) fn enter_stage(&self, stage: DictationStage, metadata: DictationStageMetadata) { + self.transition(stage, metadata, false); + } + + pub(crate) fn clear_with_stage(&self, stage: DictationStage) { + self.transition(stage, DictationStageMetadata::default(), true); + } + + pub(crate) fn snapshot(&self) -> Option { + let state = self.state.lock().ok()?; + Some(snapshot_from_state(&state)) + } + + #[cfg(test)] + fn new_with_watchdog_config(config: &Config, hang_watchdog: HangWatchdogConfig) -> Self { + Self::build(config, hang_watchdog) + } + + #[cfg(test)] + fn watchdog_enabled(&self) -> bool { + self.watchdog.enabled + } + + fn transition( + &self, + stage: DictationStage, + metadata: DictationStageMetadata, + remove_after_write: bool, + ) { + let now = now_ms(); + let snapshot = match self.state.lock() { + Ok(mut state) => { + let previous_stage = state.stage; + let previous_elapsed_ms = now.saturating_sub(state.stage_started_at_ms); + tracing::info!( + stage = previous_stage.as_str(), + elapsed_ms = previous_elapsed_ms, + "dictation stage finished" + ); + + state.stage = stage; + state.stage_started_at_ms = now; + state.metadata = metadata; + let snapshot = snapshot_from_state(&state); + tracing::info!(stage = 
stage.as_str(), "dictation stage entered"); + snapshot + } + Err(_) => { + tracing::warn!("dictation diagnostics lock poisoned; skipping stage update"); + return; + } + }; + + self.persist_snapshot_value(&snapshot); + + if remove_after_write { + self.remove_status_file(); + } + } + + fn persist_snapshot(&self) { + let Some(snapshot) = self.snapshot() else { + tracing::warn!("failed to snapshot dictation diagnostics state"); + return; + }; + self.persist_snapshot_value(&snapshot); + } + + fn persist_snapshot_value(&self, snapshot: &MainStatusSnapshot) { + if let Some(parent) = self.status_path.parent() + && let Err(err) = std::fs::create_dir_all(parent) + { + tracing::warn!( + "failed to create dictation runtime directory {}: {err}", + parent.display() + ); + return; + } + + let encoded = match serde_json::to_vec_pretty(snapshot) { + Ok(encoded) => encoded, + Err(err) => { + tracing::warn!("failed to encode dictation runtime status: {err}"); + return; + } + }; + + if let Err(err) = safe_fs::write(&self.status_path, encoded) { + tracing::warn!( + "failed to write dictation runtime status {}: {err}", + self.status_path.display() + ); + } + } + + fn start_watchdog(&self, config: HangWatchdogConfig) { + if !config.enabled { + return; + } + + let state = Arc::clone(&self.state); + let status_path = self.status_path.clone(); + let stop = Arc::clone(&self.watchdog); + let handle = std::thread::spawn(move || { + let mut dumped_stage_started_at_ms = None::; + + loop { + if stop.stop.load(Ordering::Relaxed) { + break; + } + std::thread::sleep(config.poll_interval); + + let (stage, snapshot) = match state.lock() { + Ok(state) => (state.stage, snapshot_from_state(&state)), + Err(_) => break, + }; + + if matches!(stage, DictationStage::Done | DictationStage::Cancelled) { + break; + } + + let Some(timeout) = stage_timeout(stage, config) else { + dumped_stage_started_at_ms = None; + continue; + }; + + let elapsed = now_ms().saturating_sub(snapshot.stage_started_at_ms); + if 
elapsed < timeout.as_millis() as u64 { + continue; + } + + if dumped_stage_started_at_ms == Some(snapshot.stage_started_at_ms) { + continue; + } + + dumped_stage_started_at_ms = Some(snapshot.stage_started_at_ms); + if let Err(err) = + write_hang_bundle(&status_path, stage, &snapshot, config.command_timeout) + { + tracing::warn!("failed to write hang diagnostics bundle: {err}"); + } + } + }); + + match self.watchdog.handle.lock() { + Ok(mut guard) => { + *guard = Some(handle); + } + Err(_) => { + tracing::warn!("dictation watchdog lock poisoned; detaching watchdog thread"); + } + } + } + + fn remove_status_file(&self) { + if let Err(err) = std::fs::remove_file(&self.status_path) + && err.kind() != std::io::ErrorKind::NotFound + { + tracing::warn!( + "failed to remove dictation runtime status {}: {err}", + self.status_path.display() + ); + } + } +} + +impl Drop for DictationRuntimeDiagnostics { + fn drop(&mut self) { + if Arc::strong_count(&self.ownership) == 1 { + self.watchdog.stop.store(true, Ordering::Relaxed); + if let Ok(mut guard) = self.watchdog.handle.lock() + && let Some(handle) = guard.take() + { + let _ = handle.join(); + } + self.remove_status_file(); + } + } +} + +fn snapshot_from_state(state: &State) -> MainStatusSnapshot { + MainStatusSnapshot { + pid: state.pid, + started_at_ms: state.started_at_ms, + stage: state.stage.as_str().to_string(), + stage_started_at_ms: state.stage_started_at_ms, + transcription_backend: state.transcription_backend.clone(), + asr_model: state.asr_model.clone(), + rewrite_backend: state.rewrite_backend.clone(), + rewrite_model: state.rewrite_model.clone(), + metadata: state.metadata.clone(), + } +} + +fn transcription_model_label(config: &Config) -> String { + match config.transcription.backend { + TranscriptionBackend::Cloud => config.cloud.transcription.model.clone(), + _ => { + if !config.transcription.model_path.trim().is_empty() { + config.resolved_model_path().display().to_string() + } else { + 
config.transcription.selected_model.clone() + } + } + } +} + +fn hang_watchdog_config(_config: &Config) -> HangWatchdogConfig { + HangWatchdogConfig { + enabled: std::env::var(HANG_DEBUG_ENV) + .map(|value| value == "1") + .unwrap_or(false), + prepare_timeout: DEFAULT_PREPARE_HANG_TIMEOUT, + transcribe_timeout: DEFAULT_TRANSCRIBE_HANG_TIMEOUT, + postprocess_timeout: DEFAULT_POSTPROCESS_HANG_TIMEOUT, + inject_timeout: DEFAULT_INJECT_HANG_TIMEOUT, + session_write_timeout: DEFAULT_SESSION_WRITE_HANG_TIMEOUT, + poll_interval: DEFAULT_WATCHDOG_POLL_INTERVAL, + command_timeout: DEFAULT_COMMAND_CAPTURE_TIMEOUT, + } +} + +fn stage_timeout(stage: DictationStage, config: HangWatchdogConfig) -> Option { + match stage { + DictationStage::AsrPrepare => Some(config.prepare_timeout), + DictationStage::AsrTranscribe => Some(config.transcribe_timeout), + DictationStage::Postprocess => Some(config.postprocess_timeout), + DictationStage::Inject => Some(config.inject_timeout), + DictationStage::SessionWrite => Some(config.session_write_timeout), + _ => None, + } +} + +fn hang_bundle_path(status_path: &Path, pid: u32, stage: DictationStage) -> PathBuf { + status_path + .parent() + .unwrap_or_else(|| Path::new("/tmp")) + .join(format!("hang-{pid}-{}-{}.log", stage.as_str(), now_ms())) +} + +fn write_hang_bundle( + status_path: &Path, + stage: DictationStage, + snapshot: &MainStatusSnapshot, + command_timeout: Duration, +) -> std::io::Result { + let bundle_path = hang_bundle_path(status_path, snapshot.pid, stage); + let mut body = String::new(); + let _ = writeln!(body, "whispers hang diagnostics"); + let _ = writeln!(body, "pid: {}", snapshot.pid); + let _ = writeln!(body, "stage: {}", stage.as_str()); + let _ = writeln!(body, "status_path: {}", status_path.display()); + let _ = writeln!(body); + let _ = writeln!(body, "== main-status.json =="); + match serde_json::to_string_pretty(snapshot) { + Ok(json) => { + let _ = writeln!(body, "{json}"); + } + Err(err) => { + let _ = 
writeln!(body, "failed to encode status snapshot: {err}"); + } + } + let _ = writeln!(body); + body.push_str(&capture_stack_trace(snapshot.pid, command_timeout)); + body.push('\n'); + body.push_str(&capture_lsof(snapshot.pid, command_timeout)); + + safe_fs::write(&bundle_path, body)?; + tracing::warn!( + path = %bundle_path.display(), + stage = stage.as_str(), + "wrote dictation hang diagnostics bundle" + ); + Ok(bundle_path) +} + +fn capture_stack_trace(pid: u32, timeout: Duration) -> String { + let pid = pid.to_string(); + if command_available("gstack") { + let output = run_command("gstack", &[pid.as_str()], timeout); + if output.success { + return format_command_output("gstack", &output); + } + + let mut body = format_command_output("gstack", &output); + if command_available("gdb") { + body.push('\n'); + body.push_str(&format_command_output( + "gdb", + &run_command( + "gdb", + &["-batch", "-ex", "thread apply all bt", "-p", &pid], + timeout, + ), + )); + } + return body; + } + + if command_available("gdb") { + return format_command_output( + "gdb", + &run_command( + "gdb", + &["-batch", "-ex", "thread apply all bt", "-p", &pid], + timeout, + ), + ); + } + + "== stack trace ==\nno stack capture tool available; checked gstack and gdb\n".to_string() +} + +fn capture_lsof(pid: u32, timeout: Duration) -> String { + if !command_available("lsof") { + return "== lsof ==\nlsof not available\n".to_string(); + } + + format_command_output( + "lsof", + &run_command("lsof", &["-p", &pid.to_string()], timeout), + ) +} + +#[derive(Debug)] +struct CommandCapture { + success: bool, + stdout: String, + stderr: String, + error: Option, +} + +fn run_command(program: &str, args: &[&str], timeout: Duration) -> CommandCapture { + let mut command = Command::new(program); + command.args(args); + match run_command_output_with_timeout(&mut command, timeout) { + Ok(output) => CommandCapture { + success: output.status.success(), + stdout: String::from_utf8_lossy(&output.stdout).into_owned(), 
+ stderr: String::from_utf8_lossy(&output.stderr).into_owned(), + error: None, + }, + Err(err) => CommandCapture { + success: false, + stdout: String::new(), + stderr: String::new(), + error: Some(err.to_string()), + }, + } +} + +fn format_command_output(program: &str, capture: &CommandCapture) -> String { + let mut body = String::new(); + let _ = writeln!(body, "== {program} =="); + if let Some(error) = capture.error.as_deref() { + let _ = writeln!(body, "failed to execute: {error}"); + return body; + } + let _ = writeln!(body, "success: {}", capture.success); + if !capture.stdout.trim().is_empty() { + let _ = writeln!(body, "-- stdout --"); + body.push_str(&capture.stdout); + if !capture.stdout.ends_with('\n') { + body.push('\n'); + } + } + if !capture.stderr.trim().is_empty() { + let _ = writeln!(body, "-- stderr --"); + body.push_str(&capture.stderr); + if !capture.stderr.ends_with('\n') { + body.push('\n'); + } + } + body +} + +fn command_available(program: &str) -> bool { + std::env::var_os("PATH").is_some_and(|path| { + std::env::split_paths(&path).any(|dir| { + let candidate = dir.join(program); + candidate.is_file() + }) + }) +} + +fn rewrite_model_label(config: &Config) -> String { + match config.rewrite.backend { + RewriteBackend::Cloud => config.cloud.rewrite.model.clone(), + RewriteBackend::Local => config + .resolved_rewrite_model_path() + .map(|path| path.display().to_string()) + .unwrap_or_else(|| config.rewrite.selected_model.clone()), + } +} + +fn status_file_path() -> PathBuf { + let runtime_dir = std::env::var("XDG_RUNTIME_DIR").unwrap_or_else(|_| "/tmp".into()); + PathBuf::from(runtime_dir) + .join("whispers") + .join(STATUS_FILE_NAME) +} + +fn now_ms() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_millis() as u64) + .unwrap_or(0) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::Config; + use crate::test_support::{EnvVarGuard, env_lock, set_env, unique_temp_dir}; + use 
std::ffi::CString; + use std::os::unix::ffi::OsStrExt; + use std::os::unix::fs::FileTypeExt; + use std::os::unix::fs::PermissionsExt; + + fn with_runtime_dir(f: impl FnOnce() -> T) -> T { + let _env_lock = env_lock(); + let _guard = EnvVarGuard::capture(&["PATH", "XDG_RUNTIME_DIR"]); + let runtime_dir = unique_temp_dir("runtime-diagnostics"); + let runtime_dir = runtime_dir + .to_str() + .expect("temp runtime dir should be valid UTF-8"); + set_env("XDG_RUNTIME_DIR", runtime_dir); + f() + } + + fn mkfifo(path: &Path) { + let c_path = CString::new(path.as_os_str().as_bytes()).expect("fifo path"); + let result = unsafe { libc::mkfifo(c_path.as_ptr(), 0o600) }; + assert_eq!( + result, + 0, + "mkfifo failed: {}", + std::io::Error::last_os_error() + ); + } + + #[test] + fn status_file_tracks_stage_updates_and_cleanup() { + with_runtime_dir(|| { + let diagnostics = DictationRuntimeDiagnostics::new(&Config::default()); + let status_path = status_file_path(); + assert!(status_path.exists()); + + let initial = std::fs::read_to_string(&status_path).expect("read status"); + let initial: MainStatusSnapshot = + serde_json::from_str(&initial).expect("parse initial status"); + assert_eq!(initial.stage, DictationStage::Starting.as_str()); + + diagnostics.enter_stage( + DictationStage::RecordingStopped, + DictationStageMetadata { + audio_samples: Some(32_000), + sample_rate: Some(16_000), + audio_duration_ms: Some(2_000), + ..DictationStageMetadata::default() + }, + ); + + let updated = std::fs::read_to_string(&status_path).expect("read status"); + let updated: MainStatusSnapshot = + serde_json::from_str(&updated).expect("parse updated status"); + assert_eq!(updated.stage, DictationStage::RecordingStopped.as_str()); + assert_eq!(updated.metadata.audio_samples, Some(32_000)); + + diagnostics.clear_with_stage(DictationStage::Done); + assert!(!status_path.exists()); + }); + } + + #[test] + fn status_file_removed_on_drop_without_explicit_completion() { + with_runtime_dir(|| { + let 
status_path = status_file_path(); + { + let diagnostics = DictationRuntimeDiagnostics::new(&Config::default()); + diagnostics.enter_stage( + DictationStage::AsrTranscribe, + DictationStageMetadata { + transcript_chars: Some(42), + ..DictationStageMetadata::default() + }, + ); + assert!(status_path.exists()); + } + + assert!(!status_path.exists()); + }); + } + + #[test] + fn status_file_write_skips_fifo_target() { + with_runtime_dir(|| { + let status_path = status_file_path(); + std::fs::create_dir_all(status_path.parent().expect("runtime dir")) + .expect("create runtime dir"); + mkfifo(&status_path); + + let started = std::time::Instant::now(); + let diagnostics = DictationRuntimeDiagnostics::new(&Config::default()); + diagnostics.enter_stage(DictationStage::Recording, DictationStageMetadata::default()); + + assert!( + started.elapsed() < Duration::from_secs(1), + "status snapshot writes should not block on FIFOs" + ); + assert!( + std::fs::metadata(&status_path) + .expect("status path metadata") + .file_type() + .is_fifo(), + "status snapshot target should remain the original FIFO" + ); + + diagnostics.clear_with_stage(DictationStage::Done); + assert!(!status_path.exists()); + }); + } + + #[test] + fn watchdog_arms_for_any_backend_when_enabled() { + with_runtime_dir(|| { + let mut cloud = Config::default(); + cloud.transcription.backend = TranscriptionBackend::Cloud; + let diagnostics = DictationRuntimeDiagnostics::new_with_watchdog_config( + &cloud, + HangWatchdogConfig { + enabled: true, + prepare_timeout: Duration::from_millis(20), + transcribe_timeout: Duration::from_millis(20), + postprocess_timeout: Duration::from_millis(20), + inject_timeout: Duration::from_millis(20), + session_write_timeout: Duration::from_millis(20), + poll_interval: Duration::from_millis(5), + command_timeout: Duration::from_millis(20), + }, + ); + assert!(diagnostics.watchdog_enabled()); + + let diagnostics = DictationRuntimeDiagnostics::new_with_watchdog_config( + &Config::default(), 
+ HangWatchdogConfig { + enabled: true, + prepare_timeout: Duration::from_millis(20), + transcribe_timeout: Duration::from_millis(20), + postprocess_timeout: Duration::from_millis(20), + inject_timeout: Duration::from_millis(20), + session_write_timeout: Duration::from_millis(20), + poll_interval: Duration::from_millis(5), + command_timeout: Duration::from_millis(20), + }, + ); + assert!(diagnostics.watchdog_enabled()); + }); + } + + #[test] + fn watchdog_does_not_fire_after_advancing_to_untimed_stage() { + with_runtime_dir(|| { + let diagnostics = DictationRuntimeDiagnostics::new_with_watchdog_config( + &Config::default(), + HangWatchdogConfig { + enabled: true, + prepare_timeout: Duration::from_millis(40), + transcribe_timeout: Duration::from_millis(40), + postprocess_timeout: Duration::from_millis(40), + inject_timeout: Duration::from_millis(40), + session_write_timeout: Duration::from_millis(40), + poll_interval: Duration::from_millis(5), + command_timeout: Duration::from_millis(40), + }, + ); + diagnostics.enter_stage( + DictationStage::AsrPrepare, + DictationStageMetadata::default(), + ); + std::thread::sleep(Duration::from_millis(10)); + diagnostics.enter_stage( + DictationStage::RecordingStopped, + DictationStageMetadata { + transcript_chars: Some(12), + ..DictationStageMetadata::default() + }, + ); + std::thread::sleep(Duration::from_millis(80)); + + let runtime_dir = status_file_path() + .parent() + .expect("runtime dir") + .to_path_buf(); + let dump_count = std::fs::read_dir(runtime_dir) + .expect("read runtime dir") + .flatten() + .filter(|entry| entry.file_name().to_string_lossy().starts_with("hang-")) + .count(); + assert_eq!(dump_count, 0); + }); + } + + #[test] + fn watchdog_dump_includes_stage_metadata_and_command_output() { + with_runtime_dir(|| { + let bin_dir = unique_temp_dir("runtime-diagnostics-bin"); + let gstack_path = bin_dir.join("gstack"); + let lsof_path = bin_dir.join("lsof"); + std::fs::write(&gstack_path, "#!/bin/sh\necho \"fake 
gstack $@\"\n") + .expect("write gstack"); + std::fs::write(&lsof_path, "#!/bin/sh\necho \"fake lsof $@\"\n").expect("write lsof"); + std::fs::set_permissions(&gstack_path, std::fs::Permissions::from_mode(0o755)) + .expect("chmod gstack"); + std::fs::set_permissions(&lsof_path, std::fs::Permissions::from_mode(0o755)) + .expect("chmod lsof"); + let original_path = std::env::var("PATH").unwrap_or_default(); + let path = if original_path.is_empty() { + bin_dir.display().to_string() + } else { + format!("{}:{original_path}", bin_dir.display()) + }; + set_env("PATH", &path); + + let diagnostics = DictationRuntimeDiagnostics::new_with_watchdog_config( + &Config::default(), + HangWatchdogConfig { + enabled: true, + prepare_timeout: Duration::from_millis(30), + transcribe_timeout: Duration::from_millis(30), + postprocess_timeout: Duration::from_millis(30), + inject_timeout: Duration::from_millis(30), + session_write_timeout: Duration::from_millis(30), + poll_interval: Duration::from_millis(5), + command_timeout: Duration::from_secs(1), + }, + ); + diagnostics.enter_stage( + DictationStage::AsrPrepare, + DictationStageMetadata { + audio_samples: Some(4_096), + sample_rate: Some(16_000), + audio_duration_ms: Some(256), + ..DictationStageMetadata::default() + }, + ); + + let runtime_dir = status_file_path() + .parent() + .expect("runtime dir") + .to_path_buf(); + let deadline = std::time::Instant::now() + Duration::from_secs(2); + let dump_path = loop { + let maybe_path = std::fs::read_dir(&runtime_dir) + .expect("read runtime dir") + .flatten() + .map(|entry| entry.path()) + .find(|path| { + path.file_name() + .is_some_and(|name| name.to_string_lossy().starts_with("hang-")) + }); + if let Some(path) = maybe_path { + break path; + } + assert!( + std::time::Instant::now() < deadline, + "watchdog did not emit dump" + ); + std::thread::sleep(Duration::from_millis(10)); + }; + + let dump = std::fs::read_to_string(&dump_path).expect("read dump"); + 
assert!(dump.contains("\"stage\": \"asr_prepare\"")); + assert!(dump.contains("\"audio_samples\": 4096")); + assert!(dump.contains("fake gstack")); + assert!(dump.contains("fake lsof")); + }); + } + + #[test] + fn watchdog_covers_postprocess_stage() { + with_runtime_dir(|| { + let diagnostics = DictationRuntimeDiagnostics::new_with_watchdog_config( + &Config::default(), + HangWatchdogConfig { + enabled: true, + prepare_timeout: Duration::from_secs(1), + transcribe_timeout: Duration::from_secs(1), + postprocess_timeout: Duration::from_millis(30), + inject_timeout: Duration::from_secs(1), + session_write_timeout: Duration::from_secs(1), + poll_interval: Duration::from_millis(5), + command_timeout: Duration::from_millis(20), + }, + ); + diagnostics.enter_stage( + DictationStage::Postprocess, + DictationStageMetadata { + substage: Some("planning".into()), + transcript_chars: Some(12), + ..DictationStageMetadata::default() + }, + ); + + let runtime_dir = status_file_path() + .parent() + .expect("runtime dir") + .to_path_buf(); + let deadline = std::time::Instant::now() + Duration::from_secs(2); + loop { + let found = std::fs::read_dir(&runtime_dir) + .expect("read runtime dir") + .flatten() + .any(|entry| { + entry + .file_name() + .to_string_lossy() + .contains(&format!("{}-", DictationStage::Postprocess.as_str())) + }); + if found { + break; + } + assert!( + std::time::Instant::now() < deadline, + "watchdog did not emit dump" + ); + std::thread::sleep(Duration::from_millis(10)); + } + }); + } +} diff --git a/src/runtime_guards.rs b/src/runtime_guards.rs new file mode 100644 index 0000000..f8f375c --- /dev/null +++ b/src/runtime_guards.rs @@ -0,0 +1,393 @@ +use std::io::{self, Read}; +#[cfg(unix)] +use std::os::fd::{AsRawFd, RawFd}; +use std::process::{Child, Command, ExitStatus, Output, Stdio}; +use std::thread; +use std::time::{Duration, Instant}; + +#[cfg(unix)] +use std::os::unix::process::CommandExt; + +const POLL_INTERVAL: Duration = Duration::from_millis(10); 
+const KILL_WAIT_TIMEOUT: Duration = Duration::from_millis(100); + +pub(crate) fn wait_child_with_timeout( + child: &mut Child, + timeout: Duration, +) -> io::Result> { + let deadline = Instant::now() + timeout; + loop { + if let Some(status) = child.try_wait()? { + return Ok(Some(status)); + } + if Instant::now() >= deadline { + return Ok(None); + } + thread::sleep(POLL_INTERVAL.min(deadline.saturating_duration_since(Instant::now()))); + } +} + +pub(crate) fn run_command_output_with_timeout( + command: &mut Command, + timeout: Duration, +) -> io::Result { + configure_command_for_timeout(command); + command + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + let mut child = command.spawn()?; + + #[cfg(unix)] + { + run_command_output_with_timeout_unix(&mut child, timeout) + } + + #[cfg(not(unix))] + { + run_command_output_with_timeout_fallback(&mut child, timeout) + } +} + +#[cfg(unix)] +fn run_command_output_with_timeout_unix( + child: &mut Child, + timeout: Duration, +) -> io::Result { + let mut stdout = child.stdout.take().map(NonBlockingPipe::new).transpose()?; + let mut stderr = child.stderr.take().map(NonBlockingPipe::new).transpose()?; + let deadline = Instant::now() + timeout; + let mut cleanup_deadline = None::; + let mut status = None::; + + loop { + if status.is_none() { + status = child.try_wait()?; + } + + if let Some(pipe) = stdout.as_mut() { + pipe.drain()?; + } + if let Some(pipe) = stderr.as_mut() { + pipe.drain()?; + } + + let pipes_closed = stdout.as_ref().is_none_or(NonBlockingPipe::is_closed) + && stderr.as_ref().is_none_or(NonBlockingPipe::is_closed); + if let Some(status) = status + && pipes_closed + { + if cleanup_deadline.is_some() { + return timed_out_error(timeout); + } + return Ok(Output { + status, + stdout: stdout.map_or_else(Vec::new, NonBlockingPipe::into_bytes), + stderr: stderr.map_or_else(Vec::new, NonBlockingPipe::into_bytes), + }); + } + + let now = Instant::now(); + if cleanup_deadline.is_none() && now 
>= deadline { + kill_child_with_descendants(child); + let _ = wait_child_with_timeout(child, KILL_WAIT_TIMEOUT); + cleanup_deadline = Some(now + KILL_WAIT_TIMEOUT); + } else if cleanup_deadline.is_some_and(|limit| now >= limit) { + return timed_out_error(timeout); + } + + let sleep_until = cleanup_deadline.unwrap_or(deadline); + thread::sleep(POLL_INTERVAL.min(sleep_until.saturating_duration_since(now))); + } +} + +#[cfg(not(unix))] +fn run_command_output_with_timeout_fallback( + child: &mut Child, + timeout: Duration, +) -> io::Result { + let stdout_reader = spawn_pipe_reader(child.stdout.take()); + let stderr_reader = spawn_pipe_reader(child.stderr.take()); + let status = match wait_child_with_timeout(child, timeout)? { + Some(status) => status, + None => { + kill_child_with_descendants(child); + let _ = wait_child_with_timeout(child, KILL_WAIT_TIMEOUT); + let _ = stdout_reader.join(); + let _ = stderr_reader.join(); + return timed_out_error(timeout); + } + }; + + let stdout = stdout_reader.join().unwrap_or_default(); + let stderr = stderr_reader.join().unwrap_or_default(); + Ok(Output { + status, + stdout, + stderr, + }) +} + +pub(crate) fn run_command_status_with_timeout( + command: &mut Command, + timeout: Duration, +) -> io::Result { + configure_command_for_timeout(command); + let mut child = command.spawn()?; + match wait_child_with_timeout(&mut child, timeout)? 
{ + Some(status) => Ok(status), + None => { + kill_child_with_descendants(&mut child); + let _ = wait_child_with_timeout(&mut child, KILL_WAIT_TIMEOUT); + timed_out_error(timeout) + } + } +} + +fn timed_out_error(timeout: Duration) -> io::Result { + Err(io::Error::new( + io::ErrorKind::TimedOut, + format!("command timed out after {}ms", timeout.as_millis()), + )) +} + +#[cfg(unix)] +struct NonBlockingPipe { + reader: R, + bytes: Vec, + closed: bool, +} + +#[cfg(unix)] +impl NonBlockingPipe +where + R: Read + AsRawFd, +{ + fn new(reader: R) -> io::Result { + set_nonblocking(reader.as_raw_fd())?; + Ok(Self { + reader, + bytes: Vec::new(), + closed: false, + }) + } + + fn drain(&mut self) -> io::Result<()> { + if self.closed { + return Ok(()); + } + + let mut chunk = [0_u8; 8192]; + loop { + match self.reader.read(&mut chunk) { + Ok(0) => { + self.closed = true; + return Ok(()); + } + Ok(read) => self.bytes.extend_from_slice(&chunk[..read]), + Err(err) if err.kind() == io::ErrorKind::WouldBlock => return Ok(()), + Err(err) if err.kind() == io::ErrorKind::Interrupted => continue, + Err(err) => return Err(err), + } + } + } + + fn is_closed(&self) -> bool { + self.closed + } + + fn into_bytes(self) -> Vec { + self.bytes + } +} + +#[cfg(unix)] +fn set_nonblocking(fd: RawFd) -> io::Result<()> { + let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) }; + if flags == -1 { + return Err(io::Error::last_os_error()); + } + + let result = unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) }; + if result == -1 { + return Err(io::Error::last_os_error()); + } + + Ok(()) +} + +#[cfg(not(unix))] +fn spawn_pipe_reader(reader: Option) -> thread::JoinHandle> +where + R: Read + Send + 'static, +{ + thread::spawn(move || { + let mut bytes = Vec::new(); + if let Some(mut reader) = reader { + let _ = reader.read_to_end(&mut bytes); + } + bytes + }) +} + +fn configure_command_for_timeout(command: &mut Command) { + #[cfg(unix)] + command.process_group(0); +} + +fn 
kill_child_with_descendants(child: &mut Child) { + #[cfg(unix)] + { + // Timed commands are spawned into their own process group so shell + // wrappers and background descendants cannot keep inherited pipes open + // past the timeout. + let _ = unsafe { libc::killpg(child.id() as i32, libc::SIGKILL) }; + } + + #[cfg(not(unix))] + { + let _ = child.kill(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_support::unique_temp_dir; + + #[test] + fn command_output_timeout_captures_successful_output() { + let mut command = Command::new("/bin/sh"); + command.args(["-c", "printf 'hello'; printf 'warn' >&2"]); + let output = + run_command_output_with_timeout(&mut command, Duration::from_secs(1)).expect("output"); + assert!(output.status.success()); + assert_eq!(String::from_utf8_lossy(&output.stdout), "hello"); + assert_eq!(String::from_utf8_lossy(&output.stderr), "warn"); + } + + #[test] + fn command_output_timeout_kills_hung_process() { + let mut command = Command::new("/bin/sh"); + command.args(["-c", "/bin/sleep 5"]); + let err = run_command_output_with_timeout(&mut command, Duration::from_millis(20)) + .expect_err("timeout"); + assert_eq!(err.kind(), io::ErrorKind::TimedOut); + } + + #[test] + fn command_output_timeout_kills_shell_descendants_holding_pipes_open() { + let temp_dir = unique_temp_dir("runtime-guards-output-timeout"); + let pid_path = temp_dir.join("child.pid"); + let script = format!("/bin/sleep 5 & echo $! 
> '{}' ; wait", pid_path.display()); + let mut command = Command::new("/bin/sh"); + command.args(["-c", &script]); + + let started = Instant::now(); + let err = run_command_output_with_timeout(&mut command, Duration::from_millis(20)) + .expect_err("timeout"); + + assert_eq!(err.kind(), io::ErrorKind::TimedOut); + assert!( + started.elapsed() < Duration::from_secs(1), + "timeout path should return promptly" + ); + let child_pid = std::fs::read_to_string(&pid_path) + .expect("pid file") + .trim() + .parse::() + .expect("pid should parse"); + assert!( + process_is_gone(child_pid), + "timed command descendants should be terminated" + ); + } + + #[test] + fn command_output_timeout_applies_while_draining_pipes() { + let temp_dir = unique_temp_dir("runtime-guards-output-drain-timeout"); + let pid_path = temp_dir.join("child.pid"); + let script = format!("/bin/sleep 5 & echo $! > '{}' ; exit 0", pid_path.display()); + let mut command = Command::new("/bin/sh"); + command.args(["-c", &script]); + + let started = Instant::now(); + let err = run_command_output_with_timeout(&mut command, Duration::from_millis(20)) + .expect_err("timeout"); + + assert_eq!(err.kind(), io::ErrorKind::TimedOut); + assert!( + started.elapsed() < Duration::from_secs(1), + "drain timeout should return promptly" + ); + let child_pid = std::fs::read_to_string(&pid_path) + .expect("pid file") + .trim() + .parse::() + .expect("pid should parse"); + assert!( + process_is_gone(child_pid), + "drain timeout should terminate descendants holding inherited pipes" + ); + } + + #[test] + fn command_status_timeout_returns_exit_status() { + let mut command = Command::new("/bin/sh"); + command.args(["-c", "exit 7"]); + let status = + run_command_status_with_timeout(&mut command, Duration::from_secs(1)).expect("status"); + assert_eq!(status.code(), Some(7)); + } + + #[test] + fn command_status_timeout_kills_hung_process() { + let mut command = Command::new("/bin/sh"); + command.args(["-c", "/bin/sleep 5"]); + let err 
= run_command_status_with_timeout(&mut command, Duration::from_millis(20)) + .expect_err("timeout"); + assert_eq!(err.kind(), io::ErrorKind::TimedOut); + } + + #[test] + fn command_status_timeout_kills_shell_descendants() { + let temp_dir = unique_temp_dir("runtime-guards-status-timeout"); + let pid_path = temp_dir.join("child.pid"); + let script = format!("/bin/sleep 5 & echo $! > '{}' ; wait", pid_path.display()); + let mut command = Command::new("/bin/sh"); + command.args(["-c", &script]); + + let err = run_command_status_with_timeout(&mut command, Duration::from_millis(20)) + .expect_err("timeout"); + + assert_eq!(err.kind(), io::ErrorKind::TimedOut); + let child_pid = std::fs::read_to_string(&pid_path) + .expect("pid file") + .trim() + .parse::() + .expect("pid should parse"); + assert!( + process_is_gone(child_pid), + "timed command descendants should be terminated" + ); + } + + fn process_is_gone(pid: i32) -> bool { + #[cfg(unix)] + { + let result = unsafe { libc::kill(pid, 0) }; + if result == 0 { + false + } else { + let err = io::Error::last_os_error(); + err.raw_os_error() == Some(libc::ESRCH) + } + } + + #[cfg(not(unix))] + { + let _ = pid; + true + } + } +} diff --git a/src/safe_fs.rs b/src/safe_fs.rs new file mode 100644 index 0000000..5bc0a8e --- /dev/null +++ b/src/safe_fs.rs @@ -0,0 +1,170 @@ +use std::fs; +use std::io; +use std::path::Path; + +#[cfg(unix)] +use std::io::{Read, Write}; + +#[cfg(unix)] +use std::os::unix::fs::OpenOptionsExt; + +#[cfg(unix)] +use std::os::unix::fs::FileTypeExt; + +pub(crate) fn read_to_string(path: &Path) -> io::Result { + ensure_existing_regular_file(path, "read")?; + + #[cfg(unix)] + { + let mut file = open_for_read(path)?; + ensure_regular_file(path, &file.metadata()?.file_type(), "read")?; + let mut contents = String::new(); + file.read_to_string(&mut contents)?; + Ok(contents) + } + + #[cfg(not(unix))] + { + fs::read_to_string(path) + } +} + +pub(crate) fn write(path: &Path, contents: impl AsRef<[u8]>) -> 
io::Result<()> { + ensure_regular_file_or_missing(path, "write")?; + + #[cfg(unix)] + { + let mut file = open_for_write(path)?; + ensure_regular_file(path, &file.metadata()?.file_type(), "write")?; + file.write_all(contents.as_ref())?; + Ok(()) + } + + #[cfg(not(unix))] + { + fs::write(path, contents) + } +} + +fn ensure_existing_regular_file(path: &Path, operation: &str) -> io::Result<()> { + let metadata = fs::symlink_metadata(path)?; + ensure_regular_file(path, &metadata.file_type(), operation) +} + +fn ensure_regular_file_or_missing(path: &Path, operation: &str) -> io::Result<()> { + match fs::symlink_metadata(path) { + Ok(metadata) => ensure_regular_file(path, &metadata.file_type(), operation), + Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(()), + Err(err) => Err(err), + } +} + +fn ensure_regular_file(path: &Path, file_type: &fs::FileType, operation: &str) -> io::Result<()> { + if file_type.is_file() { + return Ok(()); + } + + Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!( + "cannot {operation} {}; expected a regular file, found {}", + path.display(), + describe_file_type(file_type) + ), + )) +} + +fn describe_file_type(file_type: &fs::FileType) -> &'static str { + if file_type.is_dir() { + return "directory"; + } + if file_type.is_symlink() { + return "symlink"; + } + + #[cfg(unix)] + { + if file_type.is_fifo() { + return "fifo"; + } + if file_type.is_socket() { + return "socket"; + } + if file_type.is_block_device() { + return "block device"; + } + if file_type.is_char_device() { + return "character device"; + } + } + + "special file" +} + +#[cfg(unix)] +fn open_for_read(path: &Path) -> io::Result { + let mut options = fs::OpenOptions::new(); + options + .read(true) + .custom_flags(libc::O_NOFOLLOW | libc::O_NONBLOCK); + options.open(path) +} + +#[cfg(unix)] +fn open_for_write(path: &Path) -> io::Result { + let mut options = fs::OpenOptions::new(); + options + .write(true) + .create(true) + .truncate(true) + 
.custom_flags(libc::O_NOFOLLOW | libc::O_NONBLOCK); + options.open(path) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_support::unique_temp_dir; + + #[cfg(unix)] + use std::os::unix::fs::symlink; + + #[test] + fn write_creates_regular_file() { + let dir = unique_temp_dir("safe-fs-write"); + let path = dir.join("status.json"); + + write(&path, b"ok").expect("write should create regular file"); + + assert_eq!(fs::read_to_string(path).expect("read written file"), "ok"); + } + + #[cfg(unix)] + #[test] + fn read_rejects_symlink_even_when_target_is_regular_file() { + let dir = unique_temp_dir("safe-fs-read-symlink"); + let target = dir.join("target.txt"); + let link = dir.join("link.txt"); + fs::write(&target, "secret").expect("write target"); + symlink(&target, &link).expect("create symlink"); + + let err = read_to_string(&link).expect_err("symlink read should be rejected"); + + assert_eq!(err.kind(), io::ErrorKind::InvalidInput); + } + + #[cfg(unix)] + #[test] + fn write_rejects_symlink_even_when_target_is_regular_file() { + let dir = unique_temp_dir("safe-fs-write-symlink"); + let target = dir.join("target.txt"); + let link = dir.join("link.txt"); + fs::write(&target, "seed").expect("write target"); + symlink(&target, &link).expect("create symlink"); + + let err = write(&link, b"updated").expect_err("symlink write should be rejected"); + + assert_eq!(err.kind(), io::ErrorKind::InvalidInput); + assert_eq!(fs::read_to_string(target).expect("read target"), "seed"); + } +} diff --git a/src/session/persistence.rs b/src/session/persistence.rs index 51d111f..6140a5e 100644 --- a/src/session/persistence.rs +++ b/src/session/persistence.rs @@ -7,6 +7,7 @@ use unicode_segmentation::UnicodeSegmentation; use crate::config::SessionConfig; use crate::context::TypingContext; use crate::error::{Result, WhsprError}; +use crate::safe_fs; use super::{EligibleSessionEntry, SessionEntry, SessionRewriteSummary}; @@ -114,7 +115,7 @@ fn load_session_file() -> Result { 
/// A punctuation-sensitive literal (hostname, URL, path, email, package
/// name, …) recovered from a transcript.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct StructuredCandidate {
    pub normalized: String,
}

#[derive(Debug, Clone, PartialEq, Eq)]
enum Token {
    Word {
        text: String,
        normalized: String,
        // Whether the word was flush against its neighbor in the raw text
        // (no intervening whitespace); used to decide if literal separator
        // punctuation belongs to the candidate.
        joined_to_prev: bool,
        joined_to_next: bool,
    },
    Sep {
        ch: char,
        // True when the separator was spoken ("dot", "slash", …) rather
        // than typed literally.
        spoken: bool,
    },
}

#[derive(Debug, Clone)]
struct ParsedCandidate {
    normalized: String,
    // Index one past the last token consumed by this candidate.
    end: usize,
    word_count: usize,
    dot_clusters: usize,
    has_non_dot_cluster: bool,
    has_spoken_separator: bool,
}

/// Scans `text` for the longest confident structured-text candidate starting
/// at any token position; returns `None` when nothing passes the confidence
/// heuristics.
pub(crate) fn extract_structured_candidate(text: &str) -> Option<StructuredCandidate> {
    let tokens = tokenize(text);
    let mut best = None::<ParsedCandidate>;

    for start in 0..tokens.len() {
        let Some(parsed) = parse_candidate(&tokens, start) else {
            continue;
        };
        if !candidate_is_confident(&parsed) {
            continue;
        }

        // Prefer the longest normalized form among all confident parses.
        let replace = match &best {
            Some(best) => parsed.normalized.len() > best.normalized.len(),
            None => true,
        };
        if replace {
            best = Some(parsed);
        }
    }

    best.map(|parsed| StructuredCandidate {
        normalized: parsed.normalized,
    })
}

/// Normalizes `text` only when the WHOLE input is one confident structured
/// candidate (no prose prefix/suffix); returns `None` otherwise.
pub(crate) fn normalize_strict_structured_text(text: &str) -> Option<String> {
    let tokens = tokenize(text);
    if tokens.is_empty() {
        return None;
    }
    let parsed = parse_candidate(&tokens, 0)?;
    (parsed.end == tokens.len() && candidate_is_confident(&parsed)).then_some(parsed.normalized)
}

/// True when `text` is exactly `candidate` (after normalization), optionally
/// wrapped in meta words such as "the URL is …".
pub(crate) fn output_matches_candidate(text: &str, candidate: &str) -> bool {
    normalize_strict_structured_text(text)
        .as_deref()
        .is_some_and(|normalized| normalized == candidate)
        || meta_wrapped_candidate_matches(text, candidate)
}

fn meta_wrapped_candidate_matches(text: &str, candidate: &str) -> bool {
    let tokens = tokenize(text);
    if tokens.is_empty() {
        return false;
    }

    (0..tokens.len()).any(|start| {
        let Some(parsed) = parse_candidate(&tokens, start) else {
            return false;
        };
        candidate_is_confident(&parsed)
            && parsed.normalized == candidate
            && non_candidate_tokens_are_meta(&tokens[..start])
            && non_candidate_tokens_are_meta(&tokens[parsed.end..])
    })
}

fn non_candidate_tokens_are_meta(tokens: &[Token]) -> bool {
    tokens.iter().all(|token| match token {
        Token::Sep { ch, .. } => separator_is_meta_wrapper(*ch),
        Token::Word { normalized, .. } => is_meta_word(normalized),
    })
}

fn separator_is_meta_wrapper(ch: char) -> bool {
    matches!(ch, '.' | ':')
}

// Filler words allowed around a candidate in phrases like "the URL is …".
fn is_meta_word(word: &str) -> bool {
    matches!(
        word,
        "a" | "an"
            | "the"
            | "is"
            | "are"
            | "was"
            | "were"
            | "be"
            | "being"
            | "called"
            | "named"
            | "written"
            | "spelled"
            | "literal"
            | "literally"
            | "just"
            | "this"
            | "that"
            | "it"
            | "its"
            | "s"
            | "url"
            | "link"
            | "website"
            | "site"
            | "domain"
            | "address"
    )
}

fn tokenize(text: &str) -> Vec<Token> {
    fold_spoken_separator_tokens(raw_tokens(text))
}

/// Splits `text` into alphanumeric words and literal separator characters,
/// recording whether each word touched its neighbors without whitespace.
fn raw_tokens(text: &str) -> Vec<Token> {
    let chars = text.chars().collect::<Vec<_>>();
    let mut tokens = Vec::new();
    let mut index = 0usize;

    while index < chars.len() {
        let ch = chars[index];
        if ch.is_ascii_alphanumeric() {
            let start = index;
            index += 1;
            while index < chars.len() && chars[index].is_ascii_alphanumeric() {
                index += 1;
            }
            let text = chars[start..index].iter().collect::<String>();
            let normalized = text.chars().flat_map(|ch| ch.to_lowercase()).collect();
            tokens.push(Token::Word {
                text,
                normalized,
                joined_to_prev: start > 0 && !chars[start - 1].is_whitespace(),
                joined_to_next: index < chars.len() && !chars[index].is_whitespace(),
            });
            continue;
        }

        if matches!(
            ch,
            '.' | '/' | ':' | '_' | '-' | '@' | '?' | '#' | '=' | '&'
        ) {
            tokens.push(Token::Sep { ch, spoken: false });
        }

        index += 1;
    }

    tokens
}

/// Rewrites spoken separator phrases ("dot", "forward slash", "at sign", …)
/// into separator tokens marked `spoken: true`.
fn fold_spoken_separator_tokens(tokens: Vec<Token>) -> Vec<Token> {
    let mut output = Vec::with_capacity(tokens.len());
    let mut index = 0usize;

    while index < tokens.len() {
        if matches_word_slice(&tokens, index, &["at", "sign"]) {
            output.push(Token::Sep {
                ch: '@',
                spoken: true,
            });
            index += 2;
            continue;
        }
        if matches_word_slice(&tokens, index, &["forward", "slash"]) {
            output.push(Token::Sep {
                ch: '/',
                spoken: true,
            });
            index += 2;
            continue;
        }
        if matches_word_slice(&tokens, index, &["full", "stop"]) {
            output.push(Token::Sep {
                ch: '.',
                spoken: true,
            });
            index += 2;
            continue;
        }

        match tokens.get(index) {
            Some(Token::Word { normalized, .. }) => {
                let separator = match normalized.as_str() {
                    "dot" => Some('.'),
                    "period" => Some('.'),
                    "slash" => Some('/'),
                    "colon" => Some(':'),
                    "underscore" => Some('_'),
                    "dash" | "hyphen" => Some('-'),
                    _ => None,
                };
                if let Some(ch) = separator {
                    output.push(Token::Sep { ch, spoken: true });
                } else {
                    output.push(tokens[index].clone());
                }
            }
            Some(_) => output.push(tokens[index].clone()),
            None => {}
        }

        index += 1;
    }

    output
}

fn matches_word_slice(tokens: &[Token], index: usize, words: &[&str]) -> bool {
    if index + words.len() > tokens.len() {
        return false;
    }

    tokens[index..index + words.len()]
        .iter()
        .zip(words)
        .all(|(token, expected)| {
            matches!(
                token,
                Token::Word { normalized, .. } if normalized == expected
            )
        })
}

/// Greedily parses one candidate beginning at `tokens[start]`: an optional
/// leading separator cluster, then words joined by allowed separator
/// clusters. Returns `None` unless at least two words were consumed.
fn parse_candidate(tokens: &[Token], start: usize) -> Option<ParsedCandidate> {
    let mut index = start;
    let mut normalized = String::new();
    let mut word_count = 0usize;
    let mut dot_clusters = 0usize;
    let mut has_non_dot_cluster = false;
    let mut has_spoken_separator = false;

    // Optional leading cluster such as "/", "@", "./" before the first word.
    if let Some(Token::Sep { .. }) = tokens.get(index) {
        let cluster_start = index;
        let mut cluster = String::new();
        let mut cluster_has_spoken = false;

        while let Some(Token::Sep { ch, spoken }) = tokens.get(index) {
            cluster.push(*ch);
            cluster_has_spoken |= *spoken;
            index += 1;
        }

        if cluster.is_empty() || !allowed_leading_separator_cluster(&cluster) {
            return None;
        }

        if cluster.contains('.') {
            dot_clusters += 1;
        }
        has_non_dot_cluster |= cluster.chars().any(|ch| ch != '.');
        has_spoken_separator |= cluster_has_spoken;
        normalized.push_str(&cluster);

        if index == cluster_start {
            return None;
        }
    }

    let Token::Word {
        text: first_word,
        joined_to_next,
        ..
    } = tokens.get(index)?
    else {
        return None;
    };
    let mut previous_word_joined_to_next = *joined_to_next;
    normalized.push_str(first_word);
    word_count += 1;
    index += 1;

    loop {
        let cluster_start = index;
        let mut cluster = String::new();
        let mut cluster_has_spoken = false;

        while let Some(Token::Sep { ch, spoken }) = tokens.get(index) {
            cluster.push(*ch);
            cluster_has_spoken |= *spoken;
            index += 1;
        }

        if cluster.is_empty() || !allowed_separator_cluster(&cluster) {
            index = cluster_start;
            break;
        }

        let Some(Token::Word {
            text: word,
            normalized: normalized_word,
            joined_to_prev,
            joined_to_next,
        }) = tokens.get(index)
        else {
            // Trailing cluster with no word after it: keep only clusters
            // that are meaningful at the end of a URL ("/", "?", "#", …).
            if preserve_terminal_separator_cluster(&cluster) {
                if cluster.contains('.') {
                    dot_clusters += 1;
                }
                has_non_dot_cluster |= cluster.chars().any(|ch| ch != '.');
                has_spoken_separator |= cluster_has_spoken;
                normalized.push_str(&cluster);
            } else {
                index = cluster_start;
            }
            break;
        };

        // Literal (typed) separators must be flush against their words,
        // except "." which tolerates spaces ("portfolio. Notes. Supply").
        if !cluster_has_spoken {
            if !previous_word_joined_to_next && !separator_cluster_tolerates_leading_space(&cluster)
            {
                index = cluster_start;
                break;
            }
            if !*joined_to_prev && !separator_cluster_tolerates_trailing_space(&cluster) {
                index = cluster_start;
                break;
            }
        }

        if cluster.contains('.') {
            dot_clusters += 1;
        }
        has_non_dot_cluster |= cluster.chars().any(|ch| ch != '.');
        has_spoken_separator |= cluster_has_spoken;
        normalized.push_str(&cluster);
        // Keep original casing for genuinely typed segments; lowercase
        // segments that were separated by spaces or spoken separators.
        normalized.push_str(if *joined_to_prev && !cluster_has_spoken {
            word
        } else {
            normalized_word
        });
        previous_word_joined_to_next = *joined_to_next;
        word_count += 1;
        index += 1;
    }

    (word_count >= 2).then_some(ParsedCandidate {
        normalized,
        end: index,
        word_count,
        dot_clusters,
        has_non_dot_cluster,
        has_spoken_separator,
    })
}

fn allowed_separator_cluster(cluster: &str) -> bool {
    matches!(
        cluster,
        "." | "-"
            | "_"
            | "@"
            | "/"
            | "//"
            | ":"
            | ":/"
            | "://"
            | "?"
            | "#"
            | "="
            | "&"
            | "/?"
            | "/#"
    )
}

fn allowed_leading_separator_cluster(cluster: &str) -> bool {
    matches!(
        cluster,
        "/" | "//" | "./" | "../" | "." | "/." | "@" | "#" | "?"
    )
}

fn separator_cluster_tolerates_leading_space(cluster: &str) -> bool {
    cluster == "."
}

fn separator_cluster_tolerates_trailing_space(cluster: &str) -> bool {
    cluster == "."
}

fn preserve_terminal_separator_cluster(cluster: &str) -> bool {
    matches!(cluster, "/" | "?" | "#" | "=" | "&" | "/?" | "/#")
}

/// Confidence gate: require multiple dots, any non-dot separator, spoken
/// separator evidence, or a plausible bare hostname before treating a parse
/// as structured text.
fn candidate_is_confident(candidate: &ParsedCandidate) -> bool {
    candidate.dot_clusters >= 2
        || candidate.has_non_dot_cluster
        || (candidate.dot_clusters >= 1 && candidate.has_spoken_separator)
        || (candidate.dot_clusters >= 1 && candidate.word_count >= 3)
        || looks_like_bare_hostname(candidate)
}

fn looks_like_bare_hostname(candidate: &ParsedCandidate) -> bool {
    if candidate.dot_clusters != 1
        || candidate.word_count != 2
        || candidate.has_non_dot_cluster
        || candidate.has_spoken_separator
    {
        return false;
    }

    let mut labels = candidate.normalized.split('.');
    let (Some(host), Some(tld), None) = (labels.next(), labels.next(), labels.next()) else {
        return false;
    };

    hostname_label_is_valid(host) && likely_hostname_tld(tld)
}

fn hostname_label_is_valid(label: &str) -> bool {
    !label.is_empty()
        && label.len() <= 63
        && !label.starts_with('-')
        && !label.ends_with('-')
        && label
            .chars()
            .all(|ch| ch.is_ascii_alphanumeric() || ch == '-')
        && label.chars().any(|ch| ch.is_ascii_alphabetic())
}

/// Accepts two-letter country-code TLDs plus a curated list of common TLDs.
fn likely_hostname_tld(tld: &str) -> bool {
    let tld_lower = tld.to_ascii_lowercase();
    hostname_label_is_valid(tld)
        && ((tld.len() == 2 && tld.chars().all(|ch| ch.is_ascii_alphabetic()))
            || matches!(
                tld_lower.as_str(),
                "com"
                    | "net"
                    | "org"
                    | "edu"
                    | "gov"
                    | "mil"
                    | "app"
                    | "dev"
                    | "info"
                    | "me"
                    | "io"
                    | "ai"
                    | "co"
                    | "tv"
                    | "fm"
                    | "gg"
                    | "blog"
                    | "tech"
                    | "site"
                    | "online"
                    | "store"
                    | "cloud"
            ))
}

#[cfg(test)]
mod tests {
    use super::{
        extract_structured_candidate, normalize_strict_structured_text, output_matches_candidate,
    };

    #[test]
    fn extracts_dotted_hostname_from_literal_periods() {
        let candidate = extract_structured_candidate("portfolio. Notes. Supply is the URL")
            .expect("structured candidate");
        assert_eq!(candidate.normalized, "portfolio.notes.supply");
    }

    #[test]
    fn extracts_two_label_hostname_from_literal_periods() {
        let candidate = extract_structured_candidate("Check example.com tomorrow")
            .expect("structured candidate");
        assert_eq!(candidate.normalized, "example.com");
    }

    #[test]
    fn extracts_dotted_hostname_from_spoken_separators() {
        let candidate = extract_structured_candidate("portfolio dot notes dot supply")
            .expect("structured candidate");
        assert_eq!(candidate.normalized, "portfolio.notes.supply");
    }

    #[test]
    fn extracts_full_url_from_spoken_separators() {
        let candidate = extract_structured_candidate(
            "https colon slash slash portfolio dot notes dot supply slash blog",
        )
        .expect("structured candidate");
        assert_eq!(candidate.normalized, "https://portfolio.notes.supply/blog");
    }

    #[test]
    fn extracts_url_with_query_and_fragment_punctuation() {
        let candidate =
            extract_structured_candidate("Open https://example.com/?q=test&lang=en#frag now")
                .expect("structured candidate");
        assert_eq!(
            candidate.normalized,
            "https://example.com/?q=test&lang=en#frag"
        );
    }

    #[test]
    fn preserves_case_for_literal_structured_text() {
        let candidate =
            extract_structured_candidate("Open https://Example.com/Path/Repo?Query=Value#Frag now")
                .expect("structured candidate");
        assert_eq!(
            candidate.normalized,
            "https://Example.com/Path/Repo?Query=Value#Frag"
        );
        assert_eq!(
            normalize_strict_structured_text("MixedCase@Example.com"),
            Some("MixedCase@Example.com".into())
        );
    }

    #[test]
    fn strict_normalization_rejects_prose_suffixes() {
        assert_eq!(
            normalize_strict_structured_text("portfolio. Notes. Supply"),
            Some("portfolio.notes.supply".into())
        );
        assert_eq!(
            normalize_strict_structured_text("portfolio. Notes. Supply is the URL"),
            None
        );
    }

    #[test]
    fn strict_normalization_preserves_trailing_separator() {
        assert_eq!(
            normalize_strict_structured_text("https://example.com/"),
            Some("https://example.com/".into())
        );
        assert_eq!(
            normalize_strict_structured_text("api/v1/"),
            Some("api/v1/".into())
        );
    }

    #[test]
    fn preserves_required_leading_separators() {
        assert_eq!(
            extract_structured_candidate("Use /api/v1 now")
                .expect("structured candidate")
                .normalized,
            "/api/v1"
        );
        assert_eq!(
            extract_structured_candidate("Use @scope/package now")
                .expect("structured candidate")
                .normalized,
            "@scope/package"
        );
        assert_eq!(
            normalize_strict_structured_text("/api/v1"),
            Some("/api/v1".into())
        );
        assert_eq!(
            normalize_strict_structured_text("@scope/package"),
            Some("@scope/package".into())
        );
        assert_eq!(
            normalize_strict_structured_text("./api/v1"),
            Some("./api/v1".into())
        );
        assert_eq!(
            normalize_strict_structured_text("/.well-known/acme"),
            Some("/.well-known/acme".into())
        );
    }

    #[test]
    fn preserves_prefixed_literals_without_absorbing_preceding_words() {
        let candidate =
            extract_structured_candidate("call it @scope/package").expect("structured candidate");
        assert_eq!(candidate.normalized, "@scope/package");
        let candidate =
            extract_structured_candidate("path is /api/v1").expect("structured candidate");
        assert_eq!(candidate.normalized, "/api/v1");
        let candidate =
            extract_structured_candidate("URL: example.com").expect("structured candidate");
        assert_eq!(candidate.normalized, "example.com");
    }

    #[test]
    fn avoids_false_positive_for_normal_sentence() {
        assert_eq!(extract_structured_candidate("Section 2. Next step"), None);
        assert_eq!(extract_structured_candidate("Check two.next steps"), None);
        assert_eq!(
            normalize_strict_structured_text("Section 2. Next step"),
            None
        );
    }

    #[test]
    fn candidate_match_requires_full_structured_output() {
        assert!(output_matches_candidate(
            "portfolio . notes . supply",
            "portfolio.notes.supply"
        ));
        assert!(output_matches_candidate(
            "portfolio. Notes. Supply is the URL",
            "portfolio.notes.supply"
        ));
        assert!(output_matches_candidate(
            "https://example.com/",
            "https://example.com/"
        ));
        assert!(output_matches_candidate("URL: /api/v1", "/api/v1"));
        assert!(!output_matches_candidate(
            "https://example.com/",
            "https://example.com"
        ));
        assert!(!output_matches_candidate("/api/v1", "api/v1"));
        assert!(!output_matches_candidate("@scope/package", "scope/package"));
        assert!(!output_matches_candidate("api/v1?", "api/v1"));
        assert!(!output_matches_candidate(
            "check portfolio. Notes. Supply tomorrow",
            "portfolio.notes.supply"
        ));
    }
}

// --- src/test_support.rs: env_lock poisoning fix -------------------------

static ENV_LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();

/// Serializes tests that mutate process environment variables.
pub fn env_lock() -> std::sync::MutexGuard<'static, ()> {
    match ENV_LOCK.get_or_init(|| std::sync::Mutex::new(())).lock() {
        Ok(guard) => guard,
        // Keep later env-sensitive tests serialized even if an earlier test
        // panicked while holding the lock.
        Err(poisoned) => poisoned.into_inner(),
    }
}