Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7,931 changes: 6,626 additions & 1,305 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ rmcp = { version = "0.1.5", features = [
"transport-streamable-http-client",
"reqwest",
], default-features = false, git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "b9d7d61" } # branch = "main"

base64 = "0.22.1"
reqwest-websocket = "0.5.0"
futures-util = "0.3.31"
Expand All @@ -59,3 +60,7 @@ tower = { version = "0.5.2", features = ["util"] }
tower-http = { version = "0.6.1", features = ["fs", "trace"] }

chrono = "0.4.41"

# vad
silero_vad_burn = "0.1.1"
burn = { version = "0.20", features = ["ndarray"] }
105 changes: 98 additions & 7 deletions src/ai/bailian/realtime_asr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ impl ParaformerRealtimeV2Asr {
Ok(())
}

pub async fn start_pcm_recognition(&mut self) -> anyhow::Result<()> {
pub async fn start_pcm_recognition(
&mut self,
semantic_punctuation_enabled: bool,
) -> anyhow::Result<()> {
let task_id = Uuid::new_v4().to_string();
log::info!("Starting asr task with ID: {}", task_id);
self.task_id = task_id;
Expand All @@ -112,6 +115,7 @@ impl ParaformerRealtimeV2Asr {
"parameters": {
"format": "pcm",
"sample_rate": self.sample_rate,
"semantic_punctuation_enabled": semantic_punctuation_enabled,
},
"input": {}
},
Expand Down Expand Up @@ -163,6 +167,7 @@ impl ParaformerRealtimeV2Asr {
"streaming": "duplex"
},
"payload": {
"task_group": "audio",
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The task_group field is added without explanation or documentation. Consider adding a comment explaining why this field is necessary and what impact it has on the ASR behavior.

Copilot uses AI. Check for mistakes.
"input": {}
}
});
Expand Down Expand Up @@ -197,6 +202,7 @@ impl ParaformerRealtimeV2Asr {
} else if let Some(output) = response.payload.output {
return Ok(Some(output.sentence));
} else {
log::error!("ASR response has no output: {:?}", text);
return Err(anyhow::anyhow!("ASR error: {:?}", text));
}
}
Expand Down Expand Up @@ -226,31 +232,116 @@ async fn test_paraformer_asr() {
let mut asr = ParaformerRealtimeV2Asr::connect("", token, head.sample_rate)
.await
.unwrap();
asr.start_pcm_recognition().await.unwrap();
asr.start_pcm_recognition(false).await.unwrap();

asr.send_audio(audio_data.clone()).await.unwrap();
asr.finish_task().await.unwrap();

loop {
if let Ok(Some(sentence)) = asr.next_result().await {
println!("{:?}", sentence);
log::info!("{:?}", sentence);
if sentence.sentence_end {
println!();
log::info!("Final sentence received, ending recognition session.");
}
} else {
break;
}
}

asr.start_pcm_recognition().await.unwrap();
asr.start_pcm_recognition(false).await.unwrap();
asr.send_audio(audio_data).await.unwrap();
asr.finish_task().await.unwrap();

loop {
if let Ok(Some(sentence)) = asr.next_result().await {
println!("{:?}", sentence);
log::info!("{:?}", sentence);
if sentence.sentence_end {
log::info!("Final sentence received, ending recognition session.");
}
} else {
break;
}
}
}

// cargo test --package echokit_server --bin echokit_server -- ai::bailian::realtime_asr::test_paraformer_stream_asr --exact --show-output
#[tokio::test]
async fn test_paraformer_stream_asr() {
env_logger::init();
let token = std::env::var("COSYVOICE_TOKEN").unwrap();

let data = std::fs::read("./resources/test/out.wav").unwrap();
let mut reader = wav_io::reader::Reader::from_vec(data).expect("Failed to create WAV reader");
let header = reader.read_header().unwrap();
let mut samples = crate::util::get_samples_f32(&mut reader).unwrap();

// pad 10 seconds of silence
samples.extend_from_slice(&[0.0; 16000 * 10]);

let samples = crate::util::convert_samples_f32_to_i16_bytes(&samples);
let audio_data = bytes::Bytes::from(samples);

let mut asr = ParaformerRealtimeV2Asr::connect("", token, header.sample_rate)
.await
.unwrap();
asr.start_pcm_recognition(true).await.unwrap();

let mut ms = 0;

for chunk in audio_data.chunks(3200) {
ms += 100;
log::info!("Sending audio chunk at {} ms", ms);
asr.send_audio(Bytes::copy_from_slice(chunk)).await.unwrap();
// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let wait_asr_fut = asr.next_result();

let (sentence, has_result) = tokio::select! {
res = wait_asr_fut => {
(res.unwrap(),true)
}
_ = async {} => {
(None,false)
}
};

if has_result {
log::info!("{:?} {ms}", sentence);
}

if let Some(s) = sentence {
if s.sentence_end {
break;
}
}
}

asr.finish_task().await.unwrap();

loop {
if let Ok(Some(sentence)) = asr.next_result().await {
log::info!("{:?}", sentence);
if sentence.sentence_end {
log::info!("End of sentence");
}
} else {
break;
}
}

asr.start_pcm_recognition(true).await.unwrap();

ms = 0;
for chunk in audio_data.chunks(3200) {
ms += 100;
log::info!("Sending audio chunk at {} ms", ms);
asr.send_audio(Bytes::copy_from_slice(chunk)).await.unwrap();
// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
loop {
if let Ok(Some(sentence)) = asr.next_result().await {
log::info!("{:?}", sentence);
if sentence.sentence_end {
println!();
log::info!("End of sentence");
}
} else {
break;
Expand Down
2 changes: 1 addition & 1 deletion src/ai/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ pub async fn llm_stable<'p, I: IntoIterator<Item = C>, C: AsRef<llm::Content>>(
serde_json::to_string_pretty(&serde_json::json!(
{
"stream": true,
"messages": messages,
"last_message": messages.last(),
"model": model.to_string(),
"tools": tool_name,
"extra": extra,
Expand Down
108 changes: 107 additions & 1 deletion src/ai/vad.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
stream::{SplitSink, SplitStream},
};
use reqwest::multipart::Part;
use reqwest_websocket::{RequestBuilderExt, WebSocket};
Expand Down Expand Up @@ -101,3 +101,109 @@ impl VadRealtimeRx {
}
}
}

/// Parameters consumed by [`VadSession::new`]; an alias for the Silero VAD
/// configuration struct declared in `crate::config`.
pub type VadParams = crate::config::SileroVadConfig;

/// Cloneable factory that stores the NdArray device and the VAD parameters
/// needed to create [`VadSession`] instances.
#[derive(Clone)]
pub struct SileroVADFactory {
/// Device handed to the model constructor and to every created session.
device: burn::backend::ndarray::NdArrayDevice,
/// Detection parameters passed by reference to each new session.
params: VadParams,
}

impl SileroVADFactory {
    /// Creates a factory that uses the default NdArray device together with
    /// the supplied VAD parameters.
    ///
    /// The current implementation always returns `Ok`.
    pub fn new(params: VadParams) -> anyhow::Result<Self> {
        Ok(Self {
            device: burn::backend::ndarray::NdArrayDevice::default(),
            params,
        })
    }

    /// Instantiates a new Silero VAD model on the factory's device and wraps
    /// it in a fresh [`VadSession`] configured with the factory's parameters.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the model or the session.
    pub fn create_session(&self) -> anyhow::Result<VadSession> {
        let model = silero_vad_burn::SileroVAD6Model::new(&self.device)?;
        VadSession::new(&self.params, Box::new(model), self.device.clone())
    }
}

/// Per-stream voice-activity detector: owns a Silero VAD model, the model
/// state carried between `detect` calls, and the speech/silence bookkeeping.
pub struct VadSession {
/// Boxed Silero VAD model running on the NdArray backend.
vad: Box<silero_vad_burn::SileroVAD6Model<burn::backend::NdArray>>,
/// Model state threaded through successive `predict` calls; taken out of
/// the slot for the duration of each prediction.
state: Option<silero_vad_burn::PredictState<burn::backend::NdArray>>,
/// Device used to build input tensors and fresh prediction states.
device: burn::backend::ndarray::NdArrayDevice,

/// Whether the session currently considers the stream to be in speech.
in_speech: bool,

/// Probability above which a chunk switches the session into speech.
threshold: f32,
/// Probability below which a chunk is counted as silence.
neg_threshold: f32,

/// Number of below-`neg_threshold` chunks seen since the last speech chunk.
silence_chunk_count: usize,
/// Silence chunk count at which the session leaves the speech state.
max_silence_chunks: usize,
}

impl VadSession {
    /// Sample rate, in Hz, used to convert the configured silence duration
    /// into a number of samples.
    const SAMPLE_RATE: usize = 16000;

    /// Builds a session around an already-constructed model.
    ///
    /// * `params` - detection parameters; only `threshold`, `neg_threshold`
    ///   and `max_silence_duration_ms` are read here.
    /// * `vad` - the boxed Silero VAD model to run.
    /// * `device` - device used for input tensors and prediction state.
    ///
    /// The speech threshold is capped at `0.95`. The silence threshold
    /// defaults to `params.threshold - 0.15` when not configured and is never
    /// lower than `0.05`. The current implementation always returns `Ok`.
    pub fn new(
        params: &VadParams,
        vad: Box<silero_vad_burn::SileroVAD6Model<burn::backend::NdArray>>,
        device: burn::backend::ndarray::NdArrayDevice,
    ) -> anyhow::Result<Self> {
        let state = Some(silero_vad_burn::PredictState::default(&device));

        let neg_threshold = params
            .neg_threshold
            .unwrap_or_else(|| params.threshold - 0.15)
            .max(0.05);

        let threshold = params.threshold.min(0.95);

        // Milliseconds -> samples -> whole model chunks (integer division
        // truncates at each step).
        let max_silence_chunks = params.max_silence_duration_ms * (Self::SAMPLE_RATE / 1000)
            / silero_vad_burn::CHUNK_SIZE;

        Ok(VadSession {
            vad,
            state,
            device,

            in_speech: false,
            threshold,
            neg_threshold,

            silence_chunk_count: 0,
            max_silence_chunks,
        })
    }

    /// Discards the carried model state and clears the speech/silence
    /// bookkeeping, returning the session to its initial condition.
    pub fn reset_state(&mut self) {
        self.state = Some(silero_vad_burn::PredictState::default(&self.device));
        self.in_speech = false;
        self.silence_chunk_count = 0;
    }

    /// Runs the model on one audio chunk and returns whether the session is
    /// in speech after taking this chunk into account.
    ///
    /// The decision uses two thresholds:
    /// * probability above `threshold`: enter (or stay in) speech and reset
    ///   the silence counter;
    /// * probability below `neg_threshold`: count one silence chunk, and
    ///   leave speech once `max_silence_chunks` have been counted;
    /// * anything in between leaves the current state and counter untouched.
    ///
    /// `audio16k_chunk_512` must hold at most 512 samples (checked in debug
    /// builds only).
    ///
    /// # Errors
    ///
    /// Returns an error if the model prediction fails, if the output tensor
    /// cannot be converted to `f32` values, or if the model yields no
    /// probability at all.
    pub fn detect(&mut self, audio16k_chunk_512: &[f32]) -> anyhow::Result<bool> {
        debug_assert!(
            audio16k_chunk_512.len() <= 512,
            "audio16k_chunk_512 length must be at most 512",
        );

        let audio_tensor =
            burn::Tensor::<_, 1>::from_floats(audio16k_chunk_512, &self.device).unsqueeze();

        // `predict` consumes the state, so a failed call leaves the slot
        // empty. Start from a fresh state in that case instead of panicking
        // on the next chunk.
        let state = self
            .state
            .take()
            .unwrap_or_else(|| silero_vad_burn::PredictState::default(&self.device));
        let (state, prob) = self.vad.predict(state, audio_tensor)?;
        self.state = Some(state);

        let prob: Vec<f32> = prob.to_data().to_vec()?;
        let speech_prob = *prob
            .first()
            .ok_or_else(|| anyhow::anyhow!("VAD model returned no speech probability"))?;

        if speech_prob > self.threshold {
            self.in_speech = true;
            self.silence_chunk_count = 0;
        } else if speech_prob < self.neg_threshold {
            self.silence_chunk_count += 1;
            if self.silence_chunk_count >= self.max_silence_chunks {
                self.in_speech = false;
            }
        }

        Ok(self.in_speech)
    }

    /// Returns the chunk size constant exported by `silero_vad_burn`.
    pub const fn vad_chunk_size() -> usize {
        silero_vad_burn::CHUNK_SIZE
    }
}
54 changes: 54 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,54 @@ pub enum TTSConfig {
Elevenlabs(ElevenlabsTTS),
}

/// Configuration for the Silero voice-activity detector. Every field has a
/// serde default, so an empty section deserializes successfully.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SileroVadConfig {
/// Speech probability threshold (default `0.5`).
#[serde(default = "SileroVadConfig::default_threshold")]
pub threshold: f32,
/// Optional silence probability threshold (default `None`).
#[serde(default = "SileroVadConfig::default_neg_threshold")]
pub neg_threshold: Option<f32>,
/// Minimum speech duration in milliseconds (default `400`).
/// NOTE(review): no consumer of this field is visible here; confirm where it is used.
#[serde(default = "SileroVadConfig::default_min_speech_duration_ms")]
pub min_speech_duration_ms: usize,
/// Maximum silence duration in milliseconds (default `200`).
#[serde(default = "SileroVadConfig::default_max_silence_duration_ms")]
pub max_silence_duration_ms: usize,
/// Hangover duration in milliseconds (default `500`).
/// NOTE(review): no consumer of this field is visible here; confirm where it is used.
#[serde(default = "SileroVadConfig::hangover_ms")]
pub hangover_ms: usize,
}

// Default-value providers referenced by path from the `#[serde(default = ...)]`
// attributes on the struct fields and reused by the `Default` implementation.
impl SileroVadConfig {
/// Default for `threshold`: `0.5`.
pub fn default_threshold() -> f32 {
0.5
}

/// Default for `neg_threshold`: `None`.
pub fn default_neg_threshold() -> Option<f32> {
None
}

/// Default for `min_speech_duration_ms`: `400`.
pub fn default_min_speech_duration_ms() -> usize {
400
}

/// Default for `max_silence_duration_ms`: `200`.
pub fn default_max_silence_duration_ms() -> usize {
200
}

/// Default for the `hangover_ms` field: `500`. Unlike its siblings this
/// provider has no `default_` prefix; the serde attribute refers to it by
/// this exact name.
pub fn hangover_ms() -> usize {
500
}
}

impl Default for SileroVadConfig {
fn default() -> Self {
SileroVadConfig {
threshold: Self::default_threshold(),
neg_threshold: Self::default_neg_threshold(),
min_speech_duration_ms: Self::default_min_speech_duration_ms(),
max_silence_duration_ms: Self::default_max_silence_duration_ms(),
hangover_ms: Self::hangover_ms(),
}
}
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct WhisperASRConfig {
pub url: String,
Expand All @@ -253,8 +301,14 @@ pub struct WhisperASRConfig {
pub model: String,
#[serde(default)]
pub prompt: String,

#[serde(default)]
pub vad: SileroVadConfig,

#[deprecated]
#[serde(default)]
pub vad_url: Option<String>,
#[deprecated]
#[serde(default)]
pub vad_realtime_url: Option<String>,
}
Expand Down
15 changes: 13 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::sync::Arc;

use axum::{routing::any, Router};
use axum::{
Router,
routing::{any, get},
};
use clap::Parser;
use config::Config;

Expand Down Expand Up @@ -180,5 +183,13 @@ async fn routes(
.layer(axum::Extension(Arc::new(real_config)));
}

router
router.route(
"/version",
get(|| async {
axum::response::Json(serde_json::json!(
{
"version": env!("CARGO_PKG_VERSION"),
}))
}),
)
}
Loading