diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs
index ff824f555..9348000df 100644
--- a/backends/candle/src/lib.rs
+++ b/backends/candle/src/lib.rs
@@ -115,6 +115,23 @@ enum Config {
     XlmRoberta(BertConfig),
 }
 
+/// Helper function to check whether flash attention should be used.
+/// `require_fa2_capabilities`: true for models that need flash attention v2, false for models compatible with v1 or v2.
+fn should_use_flash_attention(require_fa2_capabilities: bool) -> bool {
+    if cfg!(not(feature = "cuda")) {
+        // Without CUDA support, flash attention is never available.
+        return false;
+    }
+    // Allow disabling at runtime because of flash attention v1 precision problems.
+    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
+    let flash_attn_enabled = std::env::var("USE_FLASH_ATTENTION").unwrap_or("true".to_string()).to_lowercase() == "true";
+    if require_fa2_capabilities {
+        cfg!(feature = "flash-attn") && flash_attn_enabled
+    } else {
+        cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) && flash_attn_enabled
+    }
+}
+
 pub struct CandleBackend {
     device: Device,
     model: Box<dyn Model + Send>,
@@ -307,12 +324,7 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::Bert(config), Device::Cuda(_)) => {
-                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    && dtype == DType::F16
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                {
+                if dtype == DType::F16 && should_use_flash_attention(false) {
                     match config {
                         BertConfigWrapper::JinaBert(config) => {
                             tracing::info!("Starting FlashJinaBert model on {:?}", device);
@@ -354,13 +366,8 @@ impl CandleBackend {
             (
                 Config::Camembert(config) | Config::Roberta(config) | Config::XlmRoberta(config),
                 Device::Cuda(_),
             ) => {
-                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    && dtype == DType::F16
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                {
+                if dtype == DType::F16 && should_use_flash_attention(false) {
                     tracing::info!("Starting FlashBert model on {:?}", device);
                     Ok(Box::new(
                         FlashBertModel::load_roberta(vb, &config, model_type).s()?,
@@ -374,12 +381,7 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::DistilBert(config), Device::Cuda(_)) => {
-                if cfg!(feature = "flash-attn")
-                    && dtype == DType::F16
-                    && &std::env::var("USE_FLASH_ATTENTION")
-                        .unwrap_or("True".to_string())
-                        .to_lowercase()
-                        == "true"
+                if dtype == DType::F16 && should_use_flash_attention(true)
                 {
                     tracing::info!("Starting FlashDistilBert model on {:?}", device);
                     Ok(Box::new(
@@ -405,12 +407,7 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::Gte(config), Device::Cuda(_)) => {
-                if dtype != DType::F16
-                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    || &std::env::var("USE_FLASH_ATTENTION")
-                        .unwrap_or("True".to_string())
-                        .to_lowercase()
-                        != "true"
+                if dtype != DType::F16 || !should_use_flash_attention(false)
                 {
                     tracing::info!("Starting GTE model on {:?}", device);
                     Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
@@ -421,13 +418,7 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::Mistral(config), Device::Cuda(_)) => {
-                if dtype != DType::F16
-                    || !cfg!(feature = "flash-attn")
-                    || get_runtime_compute_cap().unwrap() < 80
-                    || &std::env::var("USE_FLASH_ATTENTION")
-                        .unwrap_or("True".to_string())
-                        .to_lowercase()
-                        != "true"
+                if dtype != DType::F16 || !should_use_flash_attention(true)
                 {
                     return Err(BackendError::Start("Mistral is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
                 }
@@ -438,11 +429,7 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::ModernBert(config), Device::Cuda(_)) => {
-                if cfg!(feature = "flash-attn")
-                    && dtype == DType::F16
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
+                if dtype == DType::F16 && should_use_flash_attention(true)
                 {
                     tracing::info!("Starting FlashModernBert model on {:?}", device);
                     Ok(Box::new(
@@ -459,12 +446,7 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::NomicBert(config), Device::Cuda(_)) => {
-                if cfg!(feature = "flash-attn")
-                    && dtype == DType::F16
-                    && &std::env::var("USE_FLASH_ATTENTION")
-                        .unwrap_or("True".to_string())
-                        .to_lowercase()
-                        == "true"
+                if dtype == DType::F16 && should_use_flash_attention(true)
                 {
                     tracing::info!("Starting FlashNomicBert model on {:?}", device);
                     Ok(Box::new(
@@ -477,12 +459,7 @@ impl CandleBackend {
             }
            #[cfg(feature = "cuda")]
             (Config::Qwen2(config), Device::Cuda(_)) => {
-                if dtype != DType::F16
-                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    || &std::env::var("USE_FLASH_ATTENTION")
-                        .unwrap_or("True".to_string())
-                        .to_lowercase()
-                        != "true"
+                if dtype != DType::F16 || !should_use_flash_attention(false)
                 {
                     return Err(BackendError::Start("Qwen2 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
                 }
@@ -493,12 +470,7 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::Qwen3(config), Device::Cuda(_)) => {
-                if dtype != DType::F16
-                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    || &std::env::var("USE_FLASH_ATTENTION")
-                        .unwrap_or("True".to_string())
-                        .to_lowercase()
-                        != "true"
+                if dtype != DType::F16 || !should_use_flash_attention(false)
                 {
                     tracing::info!("Starting Qwen3 model on {:?}", device);
                     Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
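For reference, a minimal sketch of how the runtime toggle introduced by this patch is expected to behave, assuming a build with the `cuda` and `flash-attn` features enabled. The test module and its names are illustrative, not part of the diff, and since `should_use_flash_attention` reads `USE_FLASH_ATTENTION` from the process environment at every call, such a test would need to run serially.

```rust
// Hypothetical test sketch (not part of the patch): exercises the env-var switch
// when the crate is compiled with the `cuda` and `flash-attn` features.
#[cfg(all(test, feature = "cuda", feature = "flash-attn"))]
mod flash_attention_toggle {
    use super::should_use_flash_attention;

    #[test]
    fn use_flash_attention_env_var_is_the_runtime_switch() {
        // Explicitly disabled: the helper must report false for both v2-only
        // and v1/v2-compatible callers.
        std::env::set_var("USE_FLASH_ATTENTION", "false");
        assert!(!should_use_flash_attention(true));
        assert!(!should_use_flash_attention(false));

        // Any casing of "true" enables it again (the helper lowercases the value).
        std::env::set_var("USE_FLASH_ATTENTION", "True");
        assert!(should_use_flash_attention(true));
        assert!(should_use_flash_attention(false));

        // Unset falls back to the default of "true", i.e. flash attention stays on.
        std::env::remove_var("USE_FLASH_ATTENTION");
        assert!(should_use_flash_attention(true));
        assert!(should_use_flash_attention(false));
    }
}
```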