diff --git a/Cargo.lock b/Cargo.lock index 2060f593..58bffaba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -563,6 +563,19 @@ dependencies = [ "rayon", ] +[[package]] +name = "candle-index-select-cu" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d776145d1dfa52ab0c254ec678ae7765a120ee0bab999dbfea95d648c883a47b" +dependencies = [ + "anyhow", + "candle-core", + "half", + "num_cpus", + "rayon", +] + [[package]] name = "candle-kernels" version = "0.8.4" @@ -4501,6 +4514,7 @@ dependencies = [ "candle-cublaslt", "candle-flash-attn", "candle-flash-attn-v1", + "candle-index-select-cu", "candle-layer-norm", "candle-nn", "candle-rotary", diff --git a/Cargo.toml b/Cargo.toml index fef10914..77a89b58 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,7 @@ candle-transformers = { version = "0.8" } candle-flash-attn = { version = "0.8" } candle-cublaslt = { version = "0.0.1" } candle-layer-norm = { version = "0.0.1" } +candle-index-select-cu = { version = "0.0.1", features = ["cuda-11"], default-features = false } candle-rotary = { version = "0.0.1" } candle-flash-attn-v1 = { version = "0.0.1" } half = { version = "2.3.1", features = ["num-traits"] } diff --git a/backends/candle/Cargo.toml b/backends/candle/Cargo.toml index 73d0f417..1dbbf2ca 100644 --- a/backends/candle/Cargo.toml +++ b/backends/candle/Cargo.toml @@ -15,6 +15,7 @@ candle-transformers = { workspace = true } candle-flash-attn = { workspace = true, optional = true} candle-flash-attn-v1 = { workspace = true, optional = true } candle-cublaslt = { workspace = true, optional = true } +candle-index-select-cu = { workspace = true, optional = true, features = ["cuda-11"], default-features = false} candle-layer-norm = { workspace = true, optional = true } candle-rotary = { workspace = true, optional = true } nohash-hasher = { workspace = true } @@ -41,6 +42,6 @@ anyhow = { version = "1", features = ["backtrace"] } accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] metal = ["candle/metal", "candle-nn/metal"] mkl = ["dep:intel-mkl-src", "candle/_mkl"] -cuda = ["candle/_cuda", "candle-nn/_cuda", "dep:candle-cublaslt", "dep:candle-layer-norm", "dep:candle-rotary"] +cuda = ["candle/_cuda", "candle-nn/_cuda", "dep:candle-cublaslt", "dep:candle-layer-norm", "dep:candle-rotary", "dep:candle-index-select-cu"] flash-attn-v1 = ["dep:candle-flash-attn-v1", "cuda"] flash-attn = ["dep:candle-flash-attn", "cuda"] diff --git a/backends/candle/src/layers/index_select.rs b/backends/candle/src/layers/index_select.rs new file mode 100644 index 00000000..5d9fd809 --- /dev/null +++ b/backends/candle/src/layers/index_select.rs @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT or Apache-2.0 +// First Published under RadixMLP and https://github.com/michaelfeil/candle-index-select-cu by Michael Feil + +use candle::{Result, Tensor}; +#[cfg(feature = "cuda")] +use candle_index_select_cu; + +#[inline] +pub fn index_select(tensor: &Tensor, ids: &Tensor, dim: usize) -> Result { + #[cfg(not(feature = "cuda"))] + { + tensor.index_select(ids, dim) + } + #[cfg(feature = "cuda")] + { + candle_index_select_cu::index_select(tensor, ids, dim) + } +} diff --git a/backends/candle/src/layers/mod.rs b/backends/candle/src/layers/mod.rs index 24eb4cf7..945fc54f 100644 --- a/backends/candle/src/layers/mod.rs +++ b/backends/candle/src/layers/mod.rs @@ -1,5 +1,6 @@ #[allow(dead_code, unused)] mod cublaslt; +mod index_select; mod layer_norm; mod linear; #[allow(dead_code, unused)] @@ -11,4 +12,6 @@ pub use layer_norm::{LayerNorm, LayerNormNoBias}; pub use linear::{HiddenAct, Linear}; #[allow(unused_imports)] pub use rms_norm::RMSNorm; +#[allow(unused_imports)] +pub use index_select::index_select; pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeScaling}; diff --git a/backends/candle/src/models/flash_bert.rs b/backends/candle/src/models/flash_bert.rs index 9b14d9a0..824a6b85 100644 --- a/backends/candle/src/models/flash_bert.rs +++ b/backends/candle/src/models/flash_bert.rs @@ -1,5 +1,5 @@ use crate::flash_attn::flash_attn_varlen; -use crate::layers::{LayerNorm, Linear}; +use crate::layers::{index_select, LayerNorm, Linear}; use crate::models::bert::{ BertClassificationHead, BertConfig, BertEmbeddings, BertSpladeHead, ClassificationHead, PositionEmbeddingType, RobertaClassificationHead, @@ -419,11 +419,11 @@ impl FlashBertModel { )?; // Only select indices that requires pooling - indices = indices.index_select(&pooled_indices, 0)? + indices = index_select(&indices, &pooled_indices, 0)? } // Select tokens - Some(outputs.index_select(&indices, 0)?) + Some(index_select(&outputs, &indices, 0)?) } else { Some( match self.pool { @@ -514,7 +514,7 @@ impl FlashBertModel { Tensor::from_vec(final_indices, final_indices_length, &self.device)?; // Select the tokens with final indices - Some(outputs.index_select(&final_indices, 0)?) + Some(index_select(&outputs, &final_indices, 0)?) } else { Some(outputs) } diff --git a/backends/candle/src/models/flash_distilbert.rs b/backends/candle/src/models/flash_distilbert.rs index 2664c660..027dd545 100644 --- a/backends/candle/src/models/flash_distilbert.rs +++ b/backends/candle/src/models/flash_distilbert.rs @@ -1,5 +1,5 @@ use crate::flash_attn::flash_attn_varlen; -use crate::layers::{LayerNorm, Linear}; +use crate::layers::{index_select, LayerNorm, Linear}; use crate::models::distilbert::{ DistilBertConfig, DistilBertEmbeddings, DistilBertMLP, DistilBertSpladeHead, }; @@ -290,11 +290,11 @@ impl FlashDistilBertModel { )?; // Only select indices that requires pooling - indices = indices.index_select(&pooled_indices, 0)? + indices = index_select(&indices, &pooled_indices, 0)? } // Select tokens - outputs.index_select(&indices, 0)? + index_select(&outputs, &indices, 0)? } else { match self.pool { Pool::Cls => outputs.i(0)?, @@ -384,7 +384,7 @@ impl FlashDistilBertModel { Tensor::from_vec(final_indices, final_indices_length, &self.device)?; // Select the tokens with final indices - Some(outputs.index_select(&final_indices, 0)?) + Some(index_select(&outputs, &final_indices, 0)?) } else { Some(outputs) } diff --git a/backends/candle/src/models/flash_gte.rs b/backends/candle/src/models/flash_gte.rs index 79f80630..c1a5d511 100644 --- a/backends/candle/src/models/flash_gte.rs +++ b/backends/candle/src/models/flash_gte.rs @@ -1,5 +1,5 @@ use crate::flash_attn::flash_attn_varlen; -use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear}; +use crate::layers::{get_cos_sin, get_inv_freqs, index_select, LayerNorm, Linear}; use crate::models::gte::{GTEClassificationHead, GTEConfig, GTEMLP}; use crate::models::{Model, PositionEmbeddingType}; @@ -291,8 +291,8 @@ impl FlashGTEModel { .embeddings_norm .forward(&word_embeddings, token_type_embeddings.as_ref())?; - let cos = self.cos_cache.index_select(&position_ids, 0)?; - let sin = self.sin_cache.index_select(&position_ids, 0)?; + let cos = index_select(&self.cos_cache, &position_ids, 0)?; + let sin = index_select(&self.sin_cache, &position_ids, 0)?; for layer in &self.layers { let h = layer.forward( @@ -336,11 +336,11 @@ impl FlashGTEModel { )?; // Only select indices that requires pooling - indices = indices.index_select(&pooled_indices, 0)? + indices = index_select(&indices, &pooled_indices, 0)? } // Select tokens - Some(outputs.index_select(&indices, 0)?) + Some(index_select(&outputs, &indices, 0)?) } else { Some( match self.pool { @@ -407,7 +407,7 @@ impl FlashGTEModel { Tensor::from_vec(final_indices, final_indices_length, &self.device)?; // Select the tokens with final indices - Some(outputs.index_select(&final_indices, 0)?) + Some(index_select(&outputs, &final_indices, 0)?) } else { Some(outputs) } diff --git a/backends/candle/src/models/flash_jina.rs b/backends/candle/src/models/flash_jina.rs index 05341b84..bacf3bd1 100644 --- a/backends/candle/src/models/flash_jina.rs +++ b/backends/candle/src/models/flash_jina.rs @@ -1,6 +1,6 @@ use crate::alibi::alibi_head_slopes; use crate::flash_attn::flash_attn_varlen; -use crate::layers::{HiddenAct, LayerNorm, Linear}; +use crate::layers::{index_select, HiddenAct, LayerNorm, Linear}; use crate::models::bert::PositionEmbeddingType; use crate::models::jina::{ClassificationHead, JinaBertClassificationHead, JinaEmbeddings}; use crate::models::{BertConfig, Model}; @@ -349,11 +349,11 @@ impl FlashJinaBertModel { )?; // Only select indices that requires pooling - indices = indices.index_select(&pooled_indices, 0)? + indices = index_select(&indices, &pooled_indices, 0)? } // Select tokens - Some(outputs.index_select(&indices, 0)?) + Some(index_select(&outputs, &indices, 0)?) } else { Some( match self.pool { @@ -420,7 +420,7 @@ impl FlashJinaBertModel { Tensor::from_vec(final_indices, final_indices_length, &self.device)?; // Select the tokens with final indices - Some(outputs.index_select(&final_indices, 0)?) + Some(index_select(&outputs, &final_indices, 0)?) } else { Some(outputs) } diff --git a/backends/candle/src/models/flash_jina_code.rs b/backends/candle/src/models/flash_jina_code.rs index e00f758d..dca36103 100644 --- a/backends/candle/src/models/flash_jina_code.rs +++ b/backends/candle/src/models/flash_jina_code.rs @@ -1,6 +1,6 @@ use crate::alibi::alibi_head_slopes; use crate::flash_attn::flash_attn_varlen; -use crate::layers::{HiddenAct, LayerNorm, Linear}; +use crate::layers::{index_select, HiddenAct, LayerNorm, Linear}; use crate::models::bert::PositionEmbeddingType; use crate::models::jina::JinaEmbeddings; use crate::models::{BertConfig, Model}; @@ -395,11 +395,11 @@ impl FlashJinaCodeBertModel { )?; // Only select indices that requires pooling - indices = indices.index_select(&pooled_indices, 0)? + indices = index_select(&indices, &pooled_indices, 0)? } // Select tokens - Some(outputs.index_select(&indices, 0)?) + Some(index_select(&outputs, &indices, 0)?) } else { Some( match self.pool { @@ -466,7 +466,7 @@ impl FlashJinaCodeBertModel { Tensor::from_vec(final_indices, final_indices_length, &self.device)?; // Select the tokens with final indices - Some(outputs.index_select(&final_indices, 0)?) + Some(index_select(&outputs, &final_indices, 0)?) } else { Some(outputs) } diff --git a/backends/candle/src/models/flash_mistral.rs b/backends/candle/src/models/flash_mistral.rs index c8488f36..42d3137b 100644 --- a/backends/candle/src/models/flash_mistral.rs +++ b/backends/candle/src/models/flash_mistral.rs @@ -1,5 +1,5 @@ use crate::flash_attn::flash_attn_varlen; -use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm}; +use crate::layers::{get_cos_sin, get_inv_freqs, index_select, HiddenAct, Linear, RMSNorm}; use crate::models::{MistralConfig, Model}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; @@ -307,8 +307,8 @@ impl FlashMistralModel { let mut hidden_states = self.embeddings.forward(&input_ids)?; - let cos = self.cos_cache.index_select(&position_ids, 0)?; - let sin = self.sin_cache.index_select(&position_ids, 0)?; + let cos = index_select(&self.cos_cache, &position_ids, 0)?; + let sin = index_select(&self.sin_cache, &position_ids, 0)?; let mut residual = None; for layer in &self.layers { @@ -355,11 +355,11 @@ impl FlashMistralModel { )?; // Only select indices that requires pooling - indices = indices.index_select(&pooled_indices, 0)? + indices = index_select(&indices, &pooled_indices, 0)? } // Select tokens - Some(outputs.index_select(&indices, 0)?) + Some(index_select(&outputs, &indices, 0)?) } else { Some( match self.pool { @@ -426,7 +426,7 @@ impl FlashMistralModel { Tensor::from_vec(final_indices, final_indices_length, &self.device)?; // Select the tokens with final indices - Some(outputs.index_select(&final_indices, 0)?) + Some(index_select(&outputs, &final_indices, 0)?) } else { Some(outputs) } diff --git a/backends/candle/src/models/flash_modernbert.rs b/backends/candle/src/models/flash_modernbert.rs index 3c876c63..70de1620 100644 --- a/backends/candle/src/models/flash_modernbert.rs +++ b/backends/candle/src/models/flash_modernbert.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use crate::flash_attn::flash_attn_varlen; -use crate::layers::{get_cos_sin, get_inv_freqs, LayerNormNoBias, Linear}; +use crate::layers::{get_cos_sin, get_inv_freqs, index_select, LayerNormNoBias, Linear}; use crate::models::modernbert::{ ClassificationHead, ModernBertClassificationHead, ModernBertConfig, ModernBertEmbeddings, ModernBertMLP, @@ -343,8 +343,8 @@ impl FlashModernBertModel { for use_local_attention in [true, false] { let (cos, sin) = &self.rotary_cache[&use_local_attention]; - let cos = cos.index_select(&position_ids, 0)?; - let sin = sin.index_select(&position_ids, 0)?; + let cos = index_select(&cos, &position_ids, 0)?; + let sin = index_select(&sin, &position_ids, 0)?; rotary_cache.insert(use_local_attention, (cos, sin)); } @@ -378,10 +378,10 @@ impl FlashModernBertModel { &self.device, )?; - indices = indices.index_select(&pooled_indices, 0)? + indices = index_select(&indices, &pooled_indices, 0)? } - Some(outputs.index_select(&indices, 0)?) + Some(index_select(&outputs, &indices, 0)?) } else { Some( match self.pool { @@ -441,7 +441,7 @@ impl FlashModernBertModel { let final_indices = Tensor::from_vec(final_indices, final_indices_length, &self.device)?; - Some(outputs.index_select(&final_indices, 0)?) + Some(index_select(&outputs, &final_indices, 0)?) } else { Some(outputs) } diff --git a/backends/candle/src/models/flash_nomic.rs b/backends/candle/src/models/flash_nomic.rs index 32cd31b6..540b3c4f 100644 --- a/backends/candle/src/models/flash_nomic.rs +++ b/backends/candle/src/models/flash_nomic.rs @@ -1,5 +1,5 @@ use crate::flash_attn::flash_attn_varlen; -use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear}; +use crate::layers::{get_cos_sin, get_inv_freqs, index_select, LayerNorm, Linear}; use crate::models::nomic::{NomicBertEmbeddings, NomicMLP}; use crate::models::{Model, NomicConfig}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; @@ -285,22 +285,20 @@ impl FlashNomicBertModel { let (cos, sin) = if self.scaled_rotary_cache.is_some() && batch.max_length > self.max_trained_positions { - let cos = self - .scaled_rotary_cache - .as_ref() - .unwrap() - .0 - .index_select(&position_ids, 0)?; - let sin = self - .scaled_rotary_cache - .as_ref() - .unwrap() - .1 - .index_select(&position_ids, 0)?; + let cos = index_select( + &self.scaled_rotary_cache.as_ref().unwrap().0, + &position_ids, + 0, + )?; + let sin = index_select( + &self.scaled_rotary_cache.as_ref().unwrap().1, + &position_ids, + 0, + )?; (cos, sin) } else { - let cos = self.rotary_cache.0.index_select(&position_ids, 0)?; - let sin = self.rotary_cache.1.index_select(&position_ids, 0)?; + let cos = index_select(&self.rotary_cache.0, &position_ids, 0)?; + let sin = index_select(&self.rotary_cache.1, &position_ids, 0)?; (cos, sin) }; @@ -343,11 +341,11 @@ impl FlashNomicBertModel { )?; // Only select indices that requires pooling - indices = indices.index_select(&pooled_indices, 0)? + indices = index_select(&indices, &pooled_indices, 0)? } // Select tokens - Some(outputs.index_select(&indices, 0)?) + Some(index_select(&outputs, &indices, 0)?) } else { Some( match self.pool { diff --git a/backends/candle/src/models/flash_qwen2.rs b/backends/candle/src/models/flash_qwen2.rs index c9116311..19e71bf2 100644 --- a/backends/candle/src/models/flash_qwen2.rs +++ b/backends/candle/src/models/flash_qwen2.rs @@ -1,5 +1,5 @@ use crate::flash_attn::flash_attn_varlen; -use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm}; +use crate::layers::{get_cos_sin, get_inv_freqs, index_select, HiddenAct, Linear, RMSNorm}; use crate::models::{Model, Qwen2Config}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; @@ -325,8 +325,8 @@ impl FlashQwen2Model { let mut hidden_states = self.embeddings.forward(&input_ids)?; - let cos = self.cos_cache.index_select(&position_ids, 0)?; - let sin = self.sin_cache.index_select(&position_ids, 0)?; + let cos = index_select(&self.cos_cache, &position_ids, 0)?; + let sin = index_select(&self.sin_cache, &position_ids, 0)?; let mut residual = None; for layer in &self.layers { @@ -373,11 +373,11 @@ impl FlashQwen2Model { )?; // Only select indices that requires pooling - indices = indices.index_select(&pooled_indices, 0)? + indices = index_select(&indices, &pooled_indices, 0)? } // Select tokens - Some(outputs.index_select(&indices, 0)?) + Some(index_select(&outputs, &indices, 0)?) } else { Some( match self.pool { @@ -444,7 +444,7 @@ impl FlashQwen2Model { Tensor::from_vec(final_indices, final_indices_length, &self.device)?; // Select the tokens with final indices - Some(outputs.index_select(&final_indices, 0)?) + Some(index_select(&outputs, &final_indices, 0)?) } else { Some(outputs) } diff --git a/backends/candle/src/models/flash_qwen3.rs b/backends/candle/src/models/flash_qwen3.rs index 10f27bdd..ca0e5730 100644 --- a/backends/candle/src/models/flash_qwen3.rs +++ b/backends/candle/src/models/flash_qwen3.rs @@ -1,5 +1,5 @@ use crate::flash_attn::flash_attn_varlen; -use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm}; +use crate::layers::{get_cos_sin, get_inv_freqs, index_select, HiddenAct, Linear, RMSNorm}; use crate::models::{Model, Qwen3Config}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; @@ -373,8 +373,8 @@ impl FlashQwen3Model { let mut hidden_states = self.embeddings.forward(&input_ids)?; - let cos = self.cos_cache.index_select(&position_ids, 0)?; - let sin = self.sin_cache.index_select(&position_ids, 0)?; + let cos = index_select(&self.cos_cache, &position_ids, 0)?; + let sin = index_select(&self.sin_cache, &position_ids, 0)?; let mut residual = None; for layer in &self.layers { @@ -421,11 +421,11 @@ impl FlashQwen3Model { )?; // Only select indices that requires pooling - indices = indices.index_select(&pooled_indices, 0)? + indices = index_select(&indices, &pooled_indices, 0)? } // Select tokens - Some(outputs.index_select(&indices, 0)?) + Some(index_select(&outputs, &indices, 0)?) } else { Some( match self.pool { @@ -492,7 +492,7 @@ impl FlashQwen3Model { Tensor::from_vec(final_indices, final_indices_length, &self.device)?; // Select the tokens with final indices - Some(outputs.index_select(&final_indices, 0)?) + Some(index_select(&outputs, &final_indices, 0)?) } else { Some(outputs) }