@@ -1,200 +1,19 @@
-#[cfg(feature = "mkl")]
-extern crate intel_mkl_src;
-mod model;
+mod args;
+mod embed;
 
-use anyhow::{anyhow, Error as E, Result};
-use candle_core::Tensor;
-use candle_nn::VarBuilder;
+use anyhow::Result;
+use args::Args;
 use clap::Parser;
-use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
-use model::{BertModel, Config, DTYPE};
-use tokenizers::{PaddingParams, Tokenizer};
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
-
-    /// Run offline (you must have the files already cached)
-    #[arg(long)]
-    offline: bool,
-
-    /// Enable tracing (generates a trace-timestamp.json file).
-    #[arg(long)]
-    tracing: bool,
-
-    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
-    #[arg(long)]
-    model_id: Option<String>,
-
-    #[arg(long)]
-    revision: Option<String>,
-
-    /// When set, compute embeddings for this prompt.
-    #[arg(long)]
-    prompt: Option<String>,
-
-    /// The number of times to run the prompt.
-    #[arg(long, default_value = "1")]
-    n: usize,
-
-    /// L2 normalization for embeddings.
-    #[arg(long, default_value = "true")]
-    normalize_embeddings: bool,
-}
-
-impl Args {
-    fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
-        let device = candle_examples::device(self.cpu)?;
-        let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
-        let default_revision = "refs/pr/21".to_string();
-        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
-            (Some(model_id), Some(revision)) => (model_id, revision),
-            (Some(model_id), None) => (model_id, "main".to_string()),
-            (None, Some(revision)) => (default_model, revision),
-            (None, None) => (default_model, default_revision),
-        };
-
-        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
-            let cache = Cache::default();
-            (
-                cache
-                    .get(&repo, "config.json")
-                    .ok_or(anyhow!("Missing config file in cache"))?,
-                cache
-                    .get(&repo, "tokenizer.json")
-                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
-                cache
-                    .get(&repo, "model.safetensors")
-                    .ok_or(anyhow!("Missing weights file in cache"))?,
-            )
-        } else {
-            let api = Api::new()?;
-            let api = api.repo(repo);
-            (
-                api.get("config.json")?,
-                api.get("tokenizer.json")?,
-                api.get("model.safetensors")?,
-            )
-        };
-        let config = std::fs::read_to_string(config_filename)?;
-        let config: Config = serde_json::from_str(&config)?;
-        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-
-        let weights = unsafe { candle_core::safetensors::MmapedFile::new(weights_filename)? };
-        let weights = weights.deserialize()?;
-        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
-        let model = BertModel::load(vb, &config)?;
-        Ok((model, tokenizer))
-    }
-}
+use embed::encoder;
 
 fn main() -> Result<()> {
-    use tracing_chrome::ChromeLayerBuilder;
-    use tracing_subscriber::prelude::*;
-
     let args = Args::parse();
-    let _guard = if args.tracing {
-        println!("tracing...");
-        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-        tracing_subscriber::registry().with(chrome_layer).init();
-        Some(guard)
-    } else {
-        None
-    };
-    let start = std::time::Instant::now();
+    let (model, tokenizer) = args.build_model_and_tokenizer()?;
+    let enc = encoder::Encoder::new(model, tokenizer);
 
-    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
-    let device = &model.device;
-
-    if let Some(prompt) = args.prompt {
-        let tokenizer = tokenizer.with_padding(None).with_truncation(None);
-        let tokens = tokenizer
-            .encode(prompt, true)
-            .map_err(E::msg)?
-            .get_ids()
-            .to_vec();
-        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        println!("Loaded and encoded {:?}", start.elapsed());
-        for idx in 0..args.n {
-            let start = std::time::Instant::now();
-            let ys = model.forward(&token_ids, &token_type_ids)?;
-            if idx == 0 {
-                println!("{ys}");
-            }
-            println!("Took {:?}", start.elapsed());
-        }
-    } else {
-        let sentences = [
-            "The cat sits outside",
-            "A man is playing guitar",
-            "I love pasta",
-            "The new movie is awesome",
-            "The cat plays in the garden",
-            "A woman watches TV",
-            "The new movie is so great",
-            "Do you like pizza?",
-        ];
-        let n_sentences = sentences.len();
-        if let Some(pp) = tokenizer.get_padding_mut() {
-            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
-        } else {
-            let pp = PaddingParams {
-                strategy: tokenizers::PaddingStrategy::BatchLongest,
-                ..Default::default()
-            };
-            tokenizer.with_padding(Some(pp));
-        }
-        let tokens = tokenizer
-            .encode_batch(sentences.to_vec(), true)
-            .map_err(E::msg)?;
-        let token_ids = tokens
-            .iter()
-            .map(|tokens| {
-                let tokens = tokens.get_ids().to_vec();
-                Ok(Tensor::new(tokens.as_slice(), device)?)
-            })
-            .collect::<Result<Vec<_>>>()?;
-
-        let token_ids = Tensor::stack(&token_ids, 0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        println!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = model.forward(&token_ids, &token_type_ids)?;
-        println!("generated embeddings {:?}", embeddings.shape());
-        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
-        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
-        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
-        let embeddings = if args.normalize_embeddings {
-            normalize_l2(&embeddings)?
-        } else {
-            embeddings
-        };
-        println!("pooled embeddings {:?}", embeddings.shape());
-
-        let mut similarities = vec![];
-        for i in 0..n_sentences {
-            let e_i = embeddings.get(i)?;
-            for j in (i + 1)..n_sentences {
-                let e_j = embeddings.get(j)?;
-                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
-                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
-                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
-                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
-                similarities.push((cosine_similarity, i, j))
-            }
-        }
-        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
-        for &(score, i, j) in similarities[..5].iter() {
-            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
-        }
+    for _ in 1..10 {
+        let ys = enc.encode("Hello World").unwrap();
+        println!("{:?}", ys.shape())
     }
     Ok(())
 }
-
-pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
-    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
-}
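
The new `args` and `embed` modules are not part of this hunk, so the rewritten `main` calls into code we cannot see here. Below is a minimal sketch of what `embed/encoder.rs` could look like, reconstructed from the single-prompt path removed above; the file layout, the field names, and the import path for `BertModel` are assumptions, not the actual refactored code:

```rust
// Hypothetical embed/encoder.rs -- a sketch only; the real module is not
// shown in this diff. `embed.rs` (or embed/mod.rs) would declare `pub mod encoder;`.
use anyhow::{Error as E, Result};
use candle_core::Tensor;
use tokenizers::Tokenizer;

// Assumption: the BERT implementation stays reachable under `crate::model`,
// mirroring the old `mod model;` layout; its new home is not visible here.
use crate::model::BertModel;

pub struct Encoder {
    model: BertModel,
    tokenizer: Tokenizer,
}

impl Encoder {
    pub fn new(model: BertModel, tokenizer: Tokenizer) -> Self {
        Self { model, tokenizer }
    }

    /// Embed one sentence, returning the hidden states with shape
    /// (1, n_tokens, hidden_size) -- the shape `main` prints each iteration.
    pub fn encode(&self, sentence: &str) -> Result<Tensor> {
        let tokens = self
            .tokenizer
            .encode(sentence, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        // Batch dimension of 1 and all-zero token_type_ids, as in the removed code.
        let token_ids = Tensor::new(&tokens[..], &self.model.device)?.unsqueeze(0)?;
        let token_type_ids = token_ids.zeros_like()?;
        Ok(self.model.forward(&token_ids, &token_type_ids)?)
    }
}
```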
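Similarly, `args.rs` presumably receives the `Args` struct and `build_model_and_tokenizer` much as removed above, with `pub` visibility so `main` can keep calling `args.build_model_and_tokenizer()?`. Note that the tracing setup, the batch similarity demo, and `normalize_l2` are deleted without being re-added anywhere in this hunk, that `for _ in 1..10` runs nine times rather than ten, and that `enc.encode("Hello World").unwrap()` could use `?` instead, since `main` already returns `Result`.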