|
@@ -1,200 +1,19 @@
|
|
|
-#[cfg(feature = "mkl")]
|
|
|
-extern crate intel_mkl_src;
|
|
|
-mod model;
|
|
|
+mod args;
|
|
|
+mod embed;
|
|
|
|
|
|
-use anyhow::{anyhow, Error as E, Result};
|
|
|
-use candle_core::Tensor;
|
|
|
-use candle_nn::VarBuilder;
|
|
|
+use anyhow::Result;
|
|
|
+use args::Args;
|
|
|
use clap::Parser;
|
|
|
-use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
|
|
|
-use model::{BertModel, Config, DTYPE};
|
|
|
-use tokenizers::{PaddingParams, Tokenizer};
|
|
|
-
|
|
|
-#[derive(Parser, Debug)]
|
|
|
-#[command(author, version, about, long_about = None)]
|
|
|
-struct Args {
|
|
|
- /// Run on CPU rather than on GPU.
|
|
|
- #[arg(long)]
|
|
|
- cpu: bool,
|
|
|
-
|
|
|
- /// Run offline (you must have the files already cached)
|
|
|
- #[arg(long)]
|
|
|
- offline: bool,
|
|
|
-
|
|
|
- /// Enable tracing (generates a trace-timestamp.json file).
|
|
|
- #[arg(long)]
|
|
|
- tracing: bool,
|
|
|
-
|
|
|
- /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
|
|
- #[arg(long)]
|
|
|
- model_id: Option<String>,
|
|
|
-
|
|
|
- #[arg(long)]
|
|
|
- revision: Option<String>,
|
|
|
-
|
|
|
- /// When set, compute embeddings for this prompt.
|
|
|
- #[arg(long)]
|
|
|
- prompt: Option<String>,
|
|
|
-
|
|
|
- /// The number of times to run the prompt.
|
|
|
- #[arg(long, default_value = "1")]
|
|
|
- n: usize,
|
|
|
-
|
|
|
- /// L2 normalization for embeddings.
|
|
|
- #[arg(long, default_value = "true")]
|
|
|
- normalize_embeddings: bool,
|
|
|
-}
|
|
|
-
|
|
|
-impl Args {
|
|
|
- fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
|
|
|
- let device = candle_examples::device(self.cpu)?;
|
|
|
- let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
|
|
|
- let default_revision = "refs/pr/21".to_string();
|
|
|
- let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
|
|
|
- (Some(model_id), Some(revision)) => (model_id, revision),
|
|
|
- (Some(model_id), None) => (model_id, "main".to_string()),
|
|
|
- (None, Some(revision)) => (default_model, revision),
|
|
|
- (None, None) => (default_model, default_revision),
|
|
|
- };
|
|
|
-
|
|
|
- let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
|
|
- let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
|
|
|
- let cache = Cache::default();
|
|
|
- (
|
|
|
- cache
|
|
|
- .get(&repo, "config.json")
|
|
|
- .ok_or(anyhow!("Missing config file in cache"))?,
|
|
|
- cache
|
|
|
- .get(&repo, "tokenizer.json")
|
|
|
- .ok_or(anyhow!("Missing tokenizer file in cache"))?,
|
|
|
- cache
|
|
|
- .get(&repo, "model.safetensors")
|
|
|
- .ok_or(anyhow!("Missing weights file in cache"))?,
|
|
|
- )
|
|
|
- } else {
|
|
|
- let api = Api::new()?;
|
|
|
- let api = api.repo(repo);
|
|
|
- (
|
|
|
- api.get("config.json")?,
|
|
|
- api.get("tokenizer.json")?,
|
|
|
- api.get("model.safetensors")?,
|
|
|
- )
|
|
|
- };
|
|
|
- let config = std::fs::read_to_string(config_filename)?;
|
|
|
- let config: Config = serde_json::from_str(&config)?;
|
|
|
- let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
|
-
|
|
|
- let weights = unsafe { candle_core::safetensors::MmapedFile::new(weights_filename)? };
|
|
|
- let weights = weights.deserialize()?;
|
|
|
- let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
|
|
- let model = BertModel::load(vb, &config)?;
|
|
|
- Ok((model, tokenizer))
|
|
|
- }
|
|
|
-}
|
|
|
+use embed::encoder;
|
|
|
|
|
|
fn main() -> Result<()> {
|
|
|
- use tracing_chrome::ChromeLayerBuilder;
|
|
|
- use tracing_subscriber::prelude::*;
|
|
|
-
|
|
|
let args = Args::parse();
|
|
|
- let _guard = if args.tracing {
|
|
|
- println!("tracing...");
|
|
|
- let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
|
- tracing_subscriber::registry().with(chrome_layer).init();
|
|
|
- Some(guard)
|
|
|
- } else {
|
|
|
- None
|
|
|
- };
|
|
|
- let start = std::time::Instant::now();
|
|
|
+ let (model, tokenizer) = args.build_model_and_tokenizer()?;
|
|
|
+ let enc = encoder::Encoder::new(model, tokenizer);
|
|
|
|
|
|
- let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
|
|
- let device = &model.device;
|
|
|
-
|
|
|
- if let Some(prompt) = args.prompt {
|
|
|
- let tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
|
|
- let tokens = tokenizer
|
|
|
- .encode(prompt, true)
|
|
|
- .map_err(E::msg)?
|
|
|
- .get_ids()
|
|
|
- .to_vec();
|
|
|
- let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
|
|
- let token_type_ids = token_ids.zeros_like()?;
|
|
|
- println!("Loaded and encoded {:?}", start.elapsed());
|
|
|
- for idx in 0..args.n {
|
|
|
- let start = std::time::Instant::now();
|
|
|
- let ys = model.forward(&token_ids, &token_type_ids)?;
|
|
|
- if idx == 0 {
|
|
|
- println!("{ys}");
|
|
|
- }
|
|
|
- println!("Took {:?}", start.elapsed());
|
|
|
- }
|
|
|
- } else {
|
|
|
- let sentences = [
|
|
|
- "The cat sits outside",
|
|
|
- "A man is playing guitar",
|
|
|
- "I love pasta",
|
|
|
- "The new movie is awesome",
|
|
|
- "The cat plays in the garden",
|
|
|
- "A woman watches TV",
|
|
|
- "The new movie is so great",
|
|
|
- "Do you like pizza?",
|
|
|
- ];
|
|
|
- let n_sentences = sentences.len();
|
|
|
- if let Some(pp) = tokenizer.get_padding_mut() {
|
|
|
- pp.strategy = tokenizers::PaddingStrategy::BatchLongest
|
|
|
- } else {
|
|
|
- let pp = PaddingParams {
|
|
|
- strategy: tokenizers::PaddingStrategy::BatchLongest,
|
|
|
- ..Default::default()
|
|
|
- };
|
|
|
- tokenizer.with_padding(Some(pp));
|
|
|
- }
|
|
|
- let tokens = tokenizer
|
|
|
- .encode_batch(sentences.to_vec(), true)
|
|
|
- .map_err(E::msg)?;
|
|
|
- let token_ids = tokens
|
|
|
- .iter()
|
|
|
- .map(|tokens| {
|
|
|
- let tokens = tokens.get_ids().to_vec();
|
|
|
- Ok(Tensor::new(tokens.as_slice(), device)?)
|
|
|
- })
|
|
|
- .collect::<Result<Vec<_>>>()?;
|
|
|
-
|
|
|
- let token_ids = Tensor::stack(&token_ids, 0)?;
|
|
|
- let token_type_ids = token_ids.zeros_like()?;
|
|
|
- println!("running inference on batch {:?}", token_ids.shape());
|
|
|
- let embeddings = model.forward(&token_ids, &token_type_ids)?;
|
|
|
- println!("generated embeddings {:?}", embeddings.shape());
|
|
|
- // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
|
|
- let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
|
|
- let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
|
|
- let embeddings = if args.normalize_embeddings {
|
|
|
- normalize_l2(&embeddings)?
|
|
|
- } else {
|
|
|
- embeddings
|
|
|
- };
|
|
|
- println!("pooled embeddings {:?}", embeddings.shape());
|
|
|
-
|
|
|
- let mut similarities = vec![];
|
|
|
- for i in 0..n_sentences {
|
|
|
- let e_i = embeddings.get(i)?;
|
|
|
- for j in (i + 1)..n_sentences {
|
|
|
- let e_j = embeddings.get(j)?;
|
|
|
- let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
|
|
|
- let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
|
|
|
- let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
|
|
|
- let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
|
|
|
- similarities.push((cosine_similarity, i, j))
|
|
|
- }
|
|
|
- }
|
|
|
- similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
|
|
|
- for &(score, i, j) in similarities[..5].iter() {
|
|
|
- println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
|
|
|
- }
|
|
|
+ for _ in 1..10 {
|
|
|
+ let ys = enc.encode("Hello World").unwrap();
|
|
|
+ println!("{:?}", ys.shape())
|
|
|
}
|
|
|
Ok(())
|
|
|
}
|
|
|
-
|
|
|
-pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
|
|
|
- Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
|
|
|
-}
|