
feat(searcher): encoder poc

iwanhae 1 year ago
parent
commit
97caf0992c

+ 0 - 2
searcher/Cargo.toml

@@ -15,7 +15,5 @@ anyhow = { version = "1", features = ["backtrace"] }
 serde = { version = "1.0.171", features = ["derive"] }
 serde_json = "1.0.99"
 tracing = "0.1.37"
-tracing-chrome = "0.7.1"
-tracing-subscriber = "0.3.7"
 hf-hub = "0.2.0"
 clap = { version = "4.2.4", features = ["derive"] }

+ 88 - 0
searcher/src/args.rs

@@ -0,0 +1,88 @@
+use crate::embed;
+
+use anyhow::{anyhow, Error as E, Result};
+use candle_nn::VarBuilder;
+use clap::Parser;
+use embed::model::{BertModel, Config, DTYPE};
+use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+pub struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Run offline (you must have the files already cached)
+    #[arg(long)]
+    offline: bool,
+
+    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    /// When set, compute embeddings for this prompt.
+    #[arg(long)]
+    prompt: Option<String>,
+
+    /// The number of times to run the prompt.
+    #[arg(long, default_value = "1")]
+    n: usize,
+
+    /// L2 normalization for embeddings.
+    #[arg(long, default_value = "true")]
+    normalize_embeddings: bool,
+}
+
+impl Args {
+    pub fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
+        let device = candle_examples::device(self.cpu)?;
+        let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
+
+        // source: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/discussions/21
+        let default_revision = "refs/pr/21".to_string();
+        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+            (Some(model_id), Some(revision)) => (model_id, revision),
+            (Some(model_id), None) => (model_id, "main".to_string()),
+            (None, Some(revision)) => (default_model, revision),
+            (None, None) => (default_model, default_revision),
+        };
+
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
+            let cache = Cache::default();
+            (
+                cache
+                    .get(&repo, "config.json")
+                    .ok_or(anyhow!("Missing config file in cache"))?,
+                cache
+                    .get(&repo, "tokenizer.json")
+                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
+                cache
+                    .get(&repo, "model.safetensors")
+                    .ok_or(anyhow!("Missing weights file in cache"))?,
+            )
+        } else {
+            let api = Api::new()?;
+            let api = api.repo(repo);
+            (
+                api.get("config.json")?,
+                api.get("tokenizer.json")?,
+                api.get("model.safetensors")?,
+            )
+        };
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: Config = serde_json::from_str(&config)?;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+        let weights = unsafe { candle_core::safetensors::MmapedFile::new(weights_filename)? };
+        let weights = weights.deserialize()?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+        let model = BertModel::load(vb, &config)?;
+        Ok((model, tokenizer))
+    }
+}
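
Not part of the commit, just a hedged usage sketch of the new `Args` type: how `build_model_and_tokenizer` is expected to be driven (the new `main.rs` further down does essentially this), here forcing the offline path so all three files must already sit in the local hf-hub cache. The flag names follow clap's derive defaults for the fields above; the function name is made up for illustration.

```rust
// Hypothetical usage sketch, not part of this commit.
// Assumes the hf-hub cache already holds config.json, tokenizer.json and
// model.safetensors for the chosen repo, since --offline never hits the API.
use clap::Parser;

fn offline_example() -> anyhow::Result<()> {
    let args = Args::parse_from([
        "searcher",
        "--offline",
        "--model-id",
        "sentence-transformers/all-MiniLM-L6-v2",
        "--revision",
        "refs/pr/21",
    ]);
    // Fails with a "Missing ... file in cache" error if any of the three
    // files has not been downloaded beforehand.
    let (_model, _tokenizer) = args.build_model_and_tokenizer()?;
    Ok(())
}
```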

+ 39 - 0
searcher/src/embed/encoder.rs

@@ -0,0 +1,39 @@
+use super::model;
+
+use anyhow::Error as E;
+use candle_core::{Result, Tensor};
+use std::time::Instant;
+use tokenizers::Tokenizer;
+
+pub struct Encoder {
+    model: model::BertModel,
+    tokenizer: Tokenizer,
+}
+
+impl Encoder {
+    pub fn new(model: model::BertModel, tokenizer: Tokenizer) -> Self {
+        let tokenizer = tokenizer.clone();
+        Encoder { model, tokenizer }
+    }
+    pub fn encode(&self, prompt: &str) -> Result<Tensor> {
+        let start = Instant::now();
+        let tokens = self
+            .tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)
+            .unwrap()
+            .get_ids()
+            .to_vec();
+        let token_ids = Tensor::new(&tokens[..], &self.model.device)?.unsqueeze(0)?;
+        let token_type_ids = token_ids.zeros_like()?;
+        let embeddings = self.model.forward(&token_ids, &token_type_ids)?;
+        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+
+        // mean pooling
+        assert_eq!(embeddings.shape().dims(), [1, 128, 384]);
+        let embeddings = embeddings.sum(1)? / (n_tokens as f64);
+
+        dbg!(start.elapsed());
+        embeddings?.get(0)
+    }
+}
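
A possible next step for the searcher (not in this commit): comparing two pooled vectors returned by `Encoder::encode` with cosine similarity. The sketch below mirrors the similarity code this commit removes from `main.rs`; the helper name is invented for illustration.

```rust
// Illustrative helper only; mirrors the cosine-similarity logic removed from
// main.rs in this commit. `a` and `b` are the 1-D [hidden_size] tensors that
// Encoder::encode returns after mean pooling.
use candle_core::{Result, Tensor};

fn cosine_similarity(a: &Tensor, b: &Tensor) -> Result<f32> {
    let dot = (a * b)?.sum_all()?.to_scalar::<f32>()?;
    let norm_a = (a * a)?.sum_all()?.to_scalar::<f32>()?;
    let norm_b = (b * b)?.sum_all()?.to_scalar::<f32>()?;
    Ok(dot / (norm_a * norm_b).sqrt())
}
```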

+ 2 - 0
searcher/src/embed/mod.rs

@@ -0,0 +1,2 @@
+pub mod encoder;
+pub mod model;

+ 0 - 0
searcher/src/model.rs → searcher/src/embed/model.rs


+ 10 - 191
searcher/src/main.rs

@@ -1,200 +1,19 @@
-#[cfg(feature = "mkl")]
-extern crate intel_mkl_src;
-mod model;
+mod args;
+mod embed;
 
-use anyhow::{anyhow, Error as E, Result};
-use candle_core::Tensor;
-use candle_nn::VarBuilder;
+use anyhow::Result;
+use args::Args;
 use clap::Parser;
-use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
-use model::{BertModel, Config, DTYPE};
-use tokenizers::{PaddingParams, Tokenizer};
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
-
-    /// Run offline (you must have the files already cached)
-    #[arg(long)]
-    offline: bool,
-
-    /// Enable tracing (generates a trace-timestamp.json file).
-    #[arg(long)]
-    tracing: bool,
-
-    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
-    #[arg(long)]
-    model_id: Option<String>,
-
-    #[arg(long)]
-    revision: Option<String>,
-
-    /// When set, compute embeddings for this prompt.
-    #[arg(long)]
-    prompt: Option<String>,
-
-    /// The number of times to run the prompt.
-    #[arg(long, default_value = "1")]
-    n: usize,
-
-    /// L2 normalization for embeddings.
-    #[arg(long, default_value = "true")]
-    normalize_embeddings: bool,
-}
-
-impl Args {
-    fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
-        let device = candle_examples::device(self.cpu)?;
-        let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
-        let default_revision = "refs/pr/21".to_string();
-        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
-            (Some(model_id), Some(revision)) => (model_id, revision),
-            (Some(model_id), None) => (model_id, "main".to_string()),
-            (None, Some(revision)) => (default_model, revision),
-            (None, None) => (default_model, default_revision),
-        };
-
-        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
-            let cache = Cache::default();
-            (
-                cache
-                    .get(&repo, "config.json")
-                    .ok_or(anyhow!("Missing config file in cache"))?,
-                cache
-                    .get(&repo, "tokenizer.json")
-                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
-                cache
-                    .get(&repo, "model.safetensors")
-                    .ok_or(anyhow!("Missing weights file in cache"))?,
-            )
-        } else {
-            let api = Api::new()?;
-            let api = api.repo(repo);
-            (
-                api.get("config.json")?,
-                api.get("tokenizer.json")?,
-                api.get("model.safetensors")?,
-            )
-        };
-        let config = std::fs::read_to_string(config_filename)?;
-        let config: Config = serde_json::from_str(&config)?;
-        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-
-        let weights = unsafe { candle_core::safetensors::MmapedFile::new(weights_filename)? };
-        let weights = weights.deserialize()?;
-        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
-        let model = BertModel::load(vb, &config)?;
-        Ok((model, tokenizer))
-    }
-}
+use embed::encoder;
 
 fn main() -> Result<()> {
-    use tracing_chrome::ChromeLayerBuilder;
-    use tracing_subscriber::prelude::*;
-
     let args = Args::parse();
-    let _guard = if args.tracing {
-        println!("tracing...");
-        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-        tracing_subscriber::registry().with(chrome_layer).init();
-        Some(guard)
-    } else {
-        None
-    };
-    let start = std::time::Instant::now();
+    let (model, tokenizer) = args.build_model_and_tokenizer()?;
+    let enc = encoder::Encoder::new(model, tokenizer);
 
-    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
-    let device = &model.device;
-
-    if let Some(prompt) = args.prompt {
-        let tokenizer = tokenizer.with_padding(None).with_truncation(None);
-        let tokens = tokenizer
-            .encode(prompt, true)
-            .map_err(E::msg)?
-            .get_ids()
-            .to_vec();
-        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        println!("Loaded and encoded {:?}", start.elapsed());
-        for idx in 0..args.n {
-            let start = std::time::Instant::now();
-            let ys = model.forward(&token_ids, &token_type_ids)?;
-            if idx == 0 {
-                println!("{ys}");
-            }
-            println!("Took {:?}", start.elapsed());
-        }
-    } else {
-        let sentences = [
-            "The cat sits outside",
-            "A man is playing guitar",
-            "I love pasta",
-            "The new movie is awesome",
-            "The cat plays in the garden",
-            "A woman watches TV",
-            "The new movie is so great",
-            "Do you like pizza?",
-        ];
-        let n_sentences = sentences.len();
-        if let Some(pp) = tokenizer.get_padding_mut() {
-            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
-        } else {
-            let pp = PaddingParams {
-                strategy: tokenizers::PaddingStrategy::BatchLongest,
-                ..Default::default()
-            };
-            tokenizer.with_padding(Some(pp));
-        }
-        let tokens = tokenizer
-            .encode_batch(sentences.to_vec(), true)
-            .map_err(E::msg)?;
-        let token_ids = tokens
-            .iter()
-            .map(|tokens| {
-                let tokens = tokens.get_ids().to_vec();
-                Ok(Tensor::new(tokens.as_slice(), device)?)
-            })
-            .collect::<Result<Vec<_>>>()?;
-
-        let token_ids = Tensor::stack(&token_ids, 0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        println!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = model.forward(&token_ids, &token_type_ids)?;
-        println!("generated embeddings {:?}", embeddings.shape());
-        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
-        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
-        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
-        let embeddings = if args.normalize_embeddings {
-            normalize_l2(&embeddings)?
-        } else {
-            embeddings
-        };
-        println!("pooled embeddings {:?}", embeddings.shape());
-
-        let mut similarities = vec![];
-        for i in 0..n_sentences {
-            let e_i = embeddings.get(i)?;
-            for j in (i + 1)..n_sentences {
-                let e_j = embeddings.get(j)?;
-                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
-                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
-                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
-                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
-                similarities.push((cosine_similarity, i, j))
-            }
-        }
-        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
-        for &(score, i, j) in similarities[..5].iter() {
-            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
-        }
+    for _ in 1..10 {
+        let ys = enc.encode("Hello World").unwrap();
+        println!("{:?}", ys.shape())
     }
     Ok(())
 }
-
-pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
-    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
-}
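
One loose end worth noting: `--normalize-embeddings` survives in `args.rs`, but the new `main.rs` never reads it and the `normalize_l2` helper is deleted above. If L2 normalization is wanted again for a batched `[n_sentences, hidden_size]` tensor, the removed helper could be reinstated essentially as-is; a sketch for reference, not part of the commit.

```rust
// Sketch of the helper removed above: normalizes each row of a 2-D
// [n_sentences, hidden_size] tensor to unit L2 norm.
use candle_core::{Result, Tensor};

pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}
```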