main.rs 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. #[cfg(feature = "mkl")]
  2. extern crate intel_mkl_src;
  3. mod model;
  4. use anyhow::{anyhow, Error as E, Result};
  5. use candle_core::Tensor;
  6. use candle_nn::VarBuilder;
  7. use clap::Parser;
  8. use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
  9. use model::{BertModel, Config, DTYPE};
  10. use tokenizers::{PaddingParams, Tokenizer};
  11. #[derive(Parser, Debug)]
  12. #[command(author, version, about, long_about = None)]
  13. struct Args {
  14. /// Run on CPU rather than on GPU.
  15. #[arg(long)]
  16. cpu: bool,
  17. /// Run offline (you must have the files already cached)
  18. #[arg(long)]
  19. offline: bool,
  20. /// Enable tracing (generates a trace-timestamp.json file).
  21. #[arg(long)]
  22. tracing: bool,
  23. /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
  24. #[arg(long)]
  25. model_id: Option<String>,
  26. #[arg(long)]
  27. revision: Option<String>,
  28. /// When set, compute embeddings for this prompt.
  29. #[arg(long)]
  30. prompt: Option<String>,
  31. /// The number of times to run the prompt.
  32. #[arg(long, default_value = "1")]
  33. n: usize,
  34. /// L2 normalization for embeddings.
  35. #[arg(long, default_value = "true")]
  36. normalize_embeddings: bool,
  37. }
  38. impl Args {
  39. fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
  40. let device = candle_examples::device(self.cpu)?;
  41. let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
  42. let default_revision = "refs/pr/21".to_string();
  43. let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
  44. (Some(model_id), Some(revision)) => (model_id, revision),
  45. (Some(model_id), None) => (model_id, "main".to_string()),
  46. (None, Some(revision)) => (default_model, revision),
  47. (None, None) => (default_model, default_revision),
  48. };
  49. let repo = Repo::with_revision(model_id, RepoType::Model, revision);
  50. let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
  51. let cache = Cache::default();
  52. (
  53. cache
  54. .get(&repo, "config.json")
  55. .ok_or(anyhow!("Missing config file in cache"))?,
  56. cache
  57. .get(&repo, "tokenizer.json")
  58. .ok_or(anyhow!("Missing tokenizer file in cache"))?,
  59. cache
  60. .get(&repo, "model.safetensors")
  61. .ok_or(anyhow!("Missing weights file in cache"))?,
  62. )
  63. } else {
  64. let api = Api::new()?;
  65. let api = api.repo(repo);
  66. (
  67. api.get("config.json")?,
  68. api.get("tokenizer.json")?,
  69. api.get("model.safetensors")?,
  70. )
  71. };
  72. let config = std::fs::read_to_string(config_filename)?;
  73. let config: Config = serde_json::from_str(&config)?;
  74. let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
  75. let weights = unsafe { candle_core::safetensors::MmapedFile::new(weights_filename)? };
  76. let weights = weights.deserialize()?;
  77. let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
  78. let model = BertModel::load(vb, &config)?;
  79. Ok((model, tokenizer))
  80. }
  81. }
  82. fn main() -> Result<()> {
  83. let args = Args::parse();
  84. let start = std::time::Instant::now();
  85. let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
  86. let device = &model.device;
  87. if let Some(prompt) = args.prompt {
  88. let tokenizer = tokenizer.with_padding(None).with_truncation(None);
  89. let tokens = tokenizer
  90. .encode(prompt, true)
  91. .map_err(E::msg)?
  92. .get_ids()
  93. .to_vec();
  94. let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
  95. let token_type_ids = token_ids.zeros_like()?;
  96. println!("Loaded and encoded {:?}", start.elapsed());
  97. for idx in 0..args.n {
  98. let start = std::time::Instant::now();
  99. let ys = model.forward(&token_ids, &token_type_ids)?;
  100. if idx == 0 {
  101. println!("{ys}");
  102. }
  103. println!("Took {:?}", start.elapsed());
  104. }
  105. } else {
  106. let sentences = [
  107. "The cat sits outside",
  108. "A man is playing guitar",
  109. "I love pasta",
  110. "The new movie is awesome",
  111. "The cat plays in the garden",
  112. "A woman watches TV",
  113. "The new movie is so great",
  114. "Do you like pizza?",
  115. ];
  116. let n_sentences = sentences.len();
  117. if let Some(pp) = tokenizer.get_padding_mut() {
  118. pp.strategy = tokenizers::PaddingStrategy::BatchLongest
  119. } else {
  120. let pp = PaddingParams {
  121. strategy: tokenizers::PaddingStrategy::BatchLongest,
  122. ..Default::default()
  123. };
  124. tokenizer.with_padding(Some(pp));
  125. }
  126. let tokens = tokenizer
  127. .encode_batch(sentences.to_vec(), true)
  128. .map_err(E::msg)?;
  129. let token_ids = tokens
  130. .iter()
  131. .map(|tokens| {
  132. let tokens = tokens.get_ids().to_vec();
  133. Ok(Tensor::new(tokens.as_slice(), device)?)
  134. })
  135. .collect::<Result<Vec<_>>>()?;
  136. let token_ids = Tensor::stack(&token_ids, 0)?;
  137. let token_type_ids = token_ids.zeros_like()?;
  138. println!("running inference on batch {:?}", token_ids.shape());
  139. let embeddings = model.forward(&token_ids, &token_type_ids)?;
  140. println!("generated embeddings {:?}", embeddings.shape());
  141. // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
  142. let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
  143. let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
  144. let embeddings = if args.normalize_embeddings {
  145. normalize_l2(&embeddings)?
  146. } else {
  147. embeddings
  148. };
  149. println!("pooled embeddings {:?}", embeddings.shape());
  150. let mut similarities = vec![];
  151. for i in 0..n_sentences {
  152. let e_i = embeddings.get(i)?;
  153. for j in (i + 1)..n_sentences {
  154. let e_j = embeddings.get(j)?;
  155. let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
  156. let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
  157. let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
  158. let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
  159. similarities.push((cosine_similarity, i, j))
  160. }
  161. }
  162. similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
  163. for &(score, i, j) in similarities[..5].iter() {
  164. println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
  165. }
  166. }
  167. Ok(())
  168. }
  169. pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
  170. Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
  171. }