iwanhae 1 год назад
Родитель
Сommit
102d27427e
6 измененных файлов с 789 добавлено и 5 удалено
  1. 14 0
      searcher/.dockerignore
  2. 16 0
      searcher/Dockerfile
  3. 189 0
      searcher/ci/main.rs
  4. 568 0
      searcher/ci/model.rs
  5. 1 5
      searcher/src/args.rs
  6. 1 0
      searcher/src/main.rs

+ 14 - 0
searcher/.dockerignore

@@ -0,0 +1,14 @@
+# Generated by Cargo
+# will have compiled files and executables
+debug/
+target/
+
+# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
+# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
+Cargo.lock
+
+# These are backup files generated by rustfmt
+**/*.rs.bk
+
+# MSVC Windows builds of rustc generate these, which store debugging information
+*.pdb

+ 16 - 0
searcher/Dockerfile

@@ -0,0 +1,16 @@
+FROM rust:1.71 as builder
+WORKDIR /usr/src/kuberian
+ENV HF_HOME=/model
+COPY Cargo.toml .
+COPY ci src/
+RUN cargo run --release
+RUN rm -rf src
+COPY . .
+RUN cargo install --path .
+
+FROM debian:bullseye-slim
+ENV HF_HOME=/model
+COPY --from=builder /model /model
+COPY --from=builder /usr/local/cargo/bin/kuberian /usr/local/bin/kuberian
+EXPOSE 8080
+CMD ["kuberian"]

+ 189 - 0
searcher/ci/main.rs

@@ -0,0 +1,189 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+mod model;
+
+use anyhow::{anyhow, Error as E, Result};
+use candle_core::Tensor;
+use candle_nn::VarBuilder;
+use clap::Parser;
+use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
+use model::{BertModel, Config, DTYPE};
+use tokenizers::{PaddingParams, Tokenizer};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Run offline (you must have the files already cached)
+    #[arg(long)]
+    offline: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    /// When set, compute embeddings for this prompt.
+    #[arg(long)]
+    prompt: Option<String>,
+
+    /// The number of times to run the prompt.
+    #[arg(long, default_value = "1")]
+    n: usize,
+
+    /// L2 normalization for embeddings.
+    #[arg(long, default_value = "true")]
+    normalize_embeddings: bool,
+}
+
+impl Args {
+    fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
+        let device = candle_examples::device(self.cpu)?;
+        let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
+        let default_revision = "refs/pr/21".to_string();
+        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+            (Some(model_id), Some(revision)) => (model_id, revision),
+            (Some(model_id), None) => (model_id, "main".to_string()),
+            (None, Some(revision)) => (default_model, revision),
+            (None, None) => (default_model, default_revision),
+        };
+
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
+            let cache = Cache::default();
+            (
+                cache
+                    .get(&repo, "config.json")
+                    .ok_or(anyhow!("Missing config file in cache"))?,
+                cache
+                    .get(&repo, "tokenizer.json")
+                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
+                cache
+                    .get(&repo, "model.safetensors")
+                    .ok_or(anyhow!("Missing weights file in cache"))?,
+            )
+        } else {
+            let api = Api::new()?;
+            let api = api.repo(repo);
+            (
+                api.get("config.json")?,
+                api.get("tokenizer.json")?,
+                api.get("model.safetensors")?,
+            )
+        };
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: Config = serde_json::from_str(&config)?;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+        let weights = unsafe { candle_core::safetensors::MmapedFile::new(weights_filename)? };
+        let weights = weights.deserialize()?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+        let model = BertModel::load(vb, &config)?;
+        Ok((model, tokenizer))
+    }
+}
+
+fn main() -> Result<()> {
+    let args = Args::parse();
+    let start = std::time::Instant::now();
+
+    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
+    let device = &model.device;
+
+    if let Some(prompt) = args.prompt {
+        let tokenizer = tokenizer.with_padding(None).with_truncation(None);
+        let tokens = tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+        let token_type_ids = token_ids.zeros_like()?;
+        println!("Loaded and encoded {:?}", start.elapsed());
+        for idx in 0..args.n {
+            let start = std::time::Instant::now();
+            let ys = model.forward(&token_ids, &token_type_ids)?;
+            if idx == 0 {
+                println!("{ys}");
+            }
+            println!("Took {:?}", start.elapsed());
+        }
+    } else {
+        let sentences = [
+            "The cat sits outside",
+            "A man is playing guitar",
+            "I love pasta",
+            "The new movie is awesome",
+            "The cat plays in the garden",
+            "A woman watches TV",
+            "The new movie is so great",
+            "Do you like pizza?",
+        ];
+        let n_sentences = sentences.len();
+        if let Some(pp) = tokenizer.get_padding_mut() {
+            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+        } else {
+            let pp = PaddingParams {
+                strategy: tokenizers::PaddingStrategy::BatchLongest,
+                ..Default::default()
+            };
+            tokenizer.with_padding(Some(pp));
+        }
+        let tokens = tokenizer
+            .encode_batch(sentences.to_vec(), true)
+            .map_err(E::msg)?;
+        let token_ids = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_ids().to_vec();
+                Ok(Tensor::new(tokens.as_slice(), device)?)
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        let token_ids = Tensor::stack(&token_ids, 0)?;
+        let token_type_ids = token_ids.zeros_like()?;
+        println!("running inference on batch {:?}", token_ids.shape());
+        let embeddings = model.forward(&token_ids, &token_type_ids)?;
+        println!("generated embeddings {:?}", embeddings.shape());
+        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+        let embeddings = if args.normalize_embeddings {
+            normalize_l2(&embeddings)?
+        } else {
+            embeddings
+        };
+        println!("pooled embeddings {:?}", embeddings.shape());
+
+        let mut similarities = vec![];
+        for i in 0..n_sentences {
+            let e_i = embeddings.get(i)?;
+            for j in (i + 1)..n_sentences {
+                let e_j = embeddings.get(j)?;
+                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
+                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
+                similarities.push((cosine_similarity, i, j))
+            }
+        }
+        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
+        for &(score, i, j) in similarities[..5].iter() {
+            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
+        }
+    }
+    Ok(())
+}
+
+pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
+    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+}

+ 568 - 0
searcher/ci/model.rs

@@ -0,0 +1,568 @@
+use candle_core::{DType, Device, Result, Tensor};
+use candle_nn::{Embedding, VarBuilder};
+use serde::Deserialize;
+
+pub const DTYPE: DType = DType::F32;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[serde(rename_all = "lowercase")]
+enum HiddenAct {
+    Gelu,
+    Relu,
+}
+
+struct HiddenActLayer {
+    act: HiddenAct,
+    span: tracing::Span,
+}
+
+impl HiddenActLayer {
+    fn new(act: HiddenAct) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
+        Self { act, span }
+    }
+
+    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
+        let _enter = self.span.enter();
+        match self.act {
+            // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
+            // small numerical difference.
+            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
+            HiddenAct::Gelu => xs.gelu(),
+            HiddenAct::Relu => xs.relu(),
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct Linear {
+    weight: Tensor,
+    bias: Option<Tensor>,
+    span: tracing::Span,
+}
+
+impl Linear {
+    pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "linear");
+        Self { weight, bias, span }
+    }
+
+    pub fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
+        let _enter = self.span.enter();
+        let w = match x.dims() {
+            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+            _ => self.weight.t()?,
+        };
+        let x = x.matmul(&w)?;
+        match &self.bias {
+            None => Ok(x),
+            Some(bias) => x.broadcast_add(bias),
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct LayerNorm {
+    weight: Tensor,
+    bias: Tensor,
+    eps: f64,
+    span: tracing::Span,
+}
+
+impl LayerNorm {
+    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
+        Self {
+            weight,
+            bias,
+            eps,
+            span,
+        }
+    }
+
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let x_dtype = x.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
+        let x = x.to_dtype(internal_dtype)?;
+        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+        let x = x_normed
+            .to_dtype(x_dtype)?
+            .broadcast_mul(&self.weight)?
+            .broadcast_add(&self.bias)?;
+        Ok(x)
+    }
+}
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
+#[serde(rename_all = "lowercase")]
+enum PositionEmbeddingType {
+    #[default]
+    Absolute,
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+    vocab_size: usize,
+    hidden_size: usize,
+    num_hidden_layers: usize,
+    num_attention_heads: usize,
+    intermediate_size: usize,
+    hidden_act: HiddenAct,
+    hidden_dropout_prob: f64,
+    max_position_embeddings: usize,
+    type_vocab_size: usize,
+    initializer_range: f64,
+    layer_norm_eps: f64,
+    pad_token_id: usize,
+    #[serde(default)]
+    position_embedding_type: PositionEmbeddingType,
+    #[serde(default)]
+    use_cache: bool,
+    classifier_dropout: Option<f64>,
+    model_type: Option<String>,
+}
+
+impl Default for Config {
+    fn default() -> Self {
+        Self {
+            vocab_size: 30522,
+            hidden_size: 768,
+            num_hidden_layers: 12,
+            num_attention_heads: 12,
+            intermediate_size: 3072,
+            hidden_act: HiddenAct::Gelu,
+            hidden_dropout_prob: 0.1,
+            max_position_embeddings: 512,
+            type_vocab_size: 2,
+            initializer_range: 0.02,
+            layer_norm_eps: 1e-12,
+            pad_token_id: 0,
+            position_embedding_type: PositionEmbeddingType::Absolute,
+            use_cache: true,
+            classifier_dropout: None,
+            model_type: Some("bert".to_string()),
+        }
+    }
+}
+
+impl Config {
+    fn _all_mini_lm_l6_v2() -> Self {
+        // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
+        Self {
+            vocab_size: 30522,
+            hidden_size: 384,
+            num_hidden_layers: 6,
+            num_attention_heads: 12,
+            intermediate_size: 1536,
+            hidden_act: HiddenAct::Gelu,
+            hidden_dropout_prob: 0.1,
+            max_position_embeddings: 512,
+            type_vocab_size: 2,
+            initializer_range: 0.02,
+            layer_norm_eps: 1e-12,
+            pad_token_id: 0,
+            position_embedding_type: PositionEmbeddingType::Absolute,
+            use_cache: true,
+            classifier_dropout: None,
+            model_type: Some("bert".to_string()),
+        }
+    }
+}
+
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
+    Ok(Embedding::new(embeddings, hidden_size))
+}
+
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), "weight")?;
+    let bias = vb.get(size2, "bias")?;
+    Ok(Linear::new(weight, Some(bias)))
+}
+
+struct Dropout {
+    #[allow(dead_code)]
+    pr: f64,
+}
+
+impl Dropout {
+    fn new(pr: f64) -> Self {
+        Self { pr }
+    }
+
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        // TODO
+        Ok(x.clone())
+    }
+}
+
+fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
+        (Ok(weight), Ok(bias)) => (weight, bias),
+        (Err(err), _) | (_, Err(err)) => {
+            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
+                (weight, bias)
+            } else {
+                return Err(err);
+            }
+        }
+    };
+    Ok(LayerNorm::new(weight, bias, eps))
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
+struct BertEmbeddings {
+    word_embeddings: Embedding,
+    position_embeddings: Option<Embedding>,
+    token_type_embeddings: Embedding,
+    layer_norm: LayerNorm,
+    dropout: Dropout,
+    span: tracing::Span,
+}
+
+impl BertEmbeddings {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let word_embeddings = embedding(
+            config.vocab_size,
+            config.hidden_size,
+            vb.pp("word_embeddings"),
+        )?;
+        let position_embeddings = embedding(
+            config.max_position_embeddings,
+            config.hidden_size,
+            vb.pp("position_embeddings"),
+        )?;
+        let token_type_embeddings = embedding(
+            config.type_vocab_size,
+            config.hidden_size,
+            vb.pp("token_type_embeddings"),
+        )?;
+        let layer_norm = layer_norm(
+            config.hidden_size,
+            config.layer_norm_eps,
+            vb.pp("LayerNorm"),
+        )?;
+        Ok(Self {
+            word_embeddings,
+            position_embeddings: Some(position_embeddings),
+            token_type_embeddings,
+            layer_norm,
+            dropout: Dropout::new(config.hidden_dropout_prob),
+            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
+        })
+    }
+
+    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (_bsize, seq_len) = input_ids.dims2()?;
+        let input_embeddings = self.word_embeddings.forward(input_ids)?;
+        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
+        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
+        if let Some(position_embeddings) = &self.position_embeddings {
+            // TODO: Proper absolute positions?
+            let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
+            let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
+            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
+        }
+        let embeddings = self.layer_norm.forward(&embeddings)?;
+        let embeddings = self.dropout.forward(&embeddings)?;
+        Ok(embeddings)
+    }
+}
+
+struct BertSelfAttention {
+    query: Linear,
+    key: Linear,
+    value: Linear,
+    dropout: Dropout,
+    num_attention_heads: usize,
+    attention_head_size: usize,
+    span: tracing::Span,
+    span_softmax: tracing::Span,
+}
+
+impl BertSelfAttention {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let attention_head_size = config.hidden_size / config.num_attention_heads;
+        let all_head_size = config.num_attention_heads * attention_head_size;
+        let dropout = Dropout::new(config.hidden_dropout_prob);
+        let hidden_size = config.hidden_size;
+        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
+        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
+        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
+        Ok(Self {
+            query,
+            key,
+            value,
+            dropout,
+            num_attention_heads: config.num_attention_heads,
+            attention_head_size,
+            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
+            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
+        })
+    }
+
+    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut new_x_shape = xs.dims().to_vec();
+        new_x_shape.pop();
+        new_x_shape.push(self.num_attention_heads);
+        new_x_shape.push(self.attention_head_size);
+        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
+        xs.contiguous()
+    }
+
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let query_layer = self.query.forward(hidden_states)?;
+        let key_layer = self.key.forward(hidden_states)?;
+        let value_layer = self.value.forward(hidden_states)?;
+
+        let query_layer = self.transpose_for_scores(&query_layer)?;
+        let key_layer = self.transpose_for_scores(&key_layer)?;
+        let value_layer = self.transpose_for_scores(&value_layer)?;
+
+        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
+        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
+        let attention_probs = {
+            let _enter_sm = self.span_softmax.enter();
+            candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)?
+        };
+        let attention_probs = self.dropout.forward(&attention_probs)?;
+
+        let context_layer = attention_probs.matmul(&value_layer)?;
+        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
+        let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?;
+        Ok(context_layer)
+    }
+}
+
+struct BertSelfOutput {
+    dense: Linear,
+    layer_norm: LayerNorm,
+    dropout: Dropout,
+    span: tracing::Span,
+}
+
+impl BertSelfOutput {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
+        let layer_norm = layer_norm(
+            config.hidden_size,
+            config.layer_norm_eps,
+            vb.pp("LayerNorm"),
+        )?;
+        let dropout = Dropout::new(config.hidden_dropout_prob);
+        Ok(Self {
+            dense,
+            layer_norm,
+            dropout,
+            span: tracing::span!(tracing::Level::TRACE, "self-out"),
+        })
+    }
+
+    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let hidden_states = self.dense.forward(hidden_states)?;
+        let hidden_states = self.dropout.forward(&hidden_states)?;
+        self.layer_norm.forward(&(hidden_states + input_tensor)?)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
+struct BertAttention {
+    self_attention: BertSelfAttention,
+    self_output: BertSelfOutput,
+    span: tracing::Span,
+}
+
+impl BertAttention {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
+        let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
+        Ok(Self {
+            self_attention,
+            self_output,
+            span: tracing::span!(tracing::Level::TRACE, "attn"),
+        })
+    }
+
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let self_outputs = self.self_attention.forward(hidden_states)?;
+        let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
+        Ok(attention_output)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
+struct BertIntermediate {
+    dense: Linear,
+    intermediate_act: HiddenActLayer,
+    span: tracing::Span,
+}
+
+impl BertIntermediate {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
+        Ok(Self {
+            dense,
+            intermediate_act: HiddenActLayer::new(config.hidden_act),
+            span: tracing::span!(tracing::Level::TRACE, "inter"),
+        })
+    }
+
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let hidden_states = self.dense.forward(hidden_states)?;
+        let ys = self.intermediate_act.forward(&hidden_states)?;
+        Ok(ys)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
+struct BertOutput {
+    dense: Linear,
+    layer_norm: LayerNorm,
+    dropout: Dropout,
+    span: tracing::Span,
+}
+
+impl BertOutput {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
+        let layer_norm = layer_norm(
+            config.hidden_size,
+            config.layer_norm_eps,
+            vb.pp("LayerNorm"),
+        )?;
+        let dropout = Dropout::new(config.hidden_dropout_prob);
+        Ok(Self {
+            dense,
+            layer_norm,
+            dropout,
+            span: tracing::span!(tracing::Level::TRACE, "out"),
+        })
+    }
+
+    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let hidden_states = self.dense.forward(hidden_states)?;
+        let hidden_states = self.dropout.forward(&hidden_states)?;
+        self.layer_norm.forward(&(hidden_states + input_tensor)?)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
+struct BertLayer {
+    attention: BertAttention,
+    intermediate: BertIntermediate,
+    output: BertOutput,
+    span: tracing::Span,
+}
+
+impl BertLayer {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let attention = BertAttention::load(vb.pp("attention"), config)?;
+        let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
+        let output = BertOutput::load(vb.pp("output"), config)?;
+        Ok(Self {
+            attention,
+            intermediate,
+            output,
+            span: tracing::span!(tracing::Level::TRACE, "layer"),
+        })
+    }
+
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let attention_output = self.attention.forward(hidden_states)?;
+        // TODO: Support cross-attention?
+        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
+        // TODO: Support something similar to `apply_chunking_to_forward`?
+        let intermediate_output = self.intermediate.forward(&attention_output)?;
+        let layer_output = self
+            .output
+            .forward(&intermediate_output, &attention_output)?;
+        Ok(layer_output)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
+struct BertEncoder {
+    layers: Vec<BertLayer>,
+    span: tracing::Span,
+}
+
+impl BertEncoder {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let layers = (0..config.num_hidden_layers)
+            .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
+            .collect::<Result<Vec<_>>>()?;
+        let span = tracing::span!(tracing::Level::TRACE, "encoder");
+        Ok(BertEncoder { layers, span })
+    }
+
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let mut hidden_states = hidden_states.clone();
+        // Use a loop rather than a fold as it's easier to modify when adding debug/...
+        for layer in self.layers.iter() {
+            hidden_states = layer.forward(&hidden_states)?
+        }
+        Ok(hidden_states)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
+pub struct BertModel {
+    embeddings: BertEmbeddings,
+    encoder: BertEncoder,
+    pub device: Device,
+    span: tracing::Span,
+}
+
+impl BertModel {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let (embeddings, encoder) = match (
+            BertEmbeddings::load(vb.pp("embeddings"), config),
+            BertEncoder::load(vb.pp("encoder"), config),
+        ) {
+            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
+            (Err(err), _) | (_, Err(err)) => {
+                if let Some(model_type) = &config.model_type {
+                    if let (Ok(embeddings), Ok(encoder)) = (
+                        BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
+                        BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
+                    ) {
+                        (embeddings, encoder)
+                    } else {
+                        return Err(err);
+                    }
+                } else {
+                    return Err(err);
+                }
+            }
+        };
+        Ok(Self {
+            embeddings,
+            encoder,
+            device: vb.device().clone(),
+            span: tracing::span!(tracing::Level::TRACE, "model"),
+        })
+    }
+
+    pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
+        let sequence_output = self.encoder.forward(&embedding_output)?;
+        Ok(sequence_output)
+    }
+}

+ 1 - 5
searcher/src/args.rs

@@ -10,10 +10,6 @@ use tokenizers::Tokenizer;
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 pub struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
-
     /// Run offline (you must have the files already cached)
     #[arg(long)]
     offline: bool,
@@ -28,7 +24,7 @@ pub struct Args {
 
 impl Args {
     pub fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
-        let device = candle_examples::device(self.cpu)?;
+        let device = candle_examples::device(true)?;
         let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
 
         // source: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/discussions/21

+ 1 - 0
searcher/src/main.rs

@@ -25,6 +25,7 @@ async fn main() -> std::io::Result<()> {
     let enc = encoder::Encoder::new(model, tokenizer);
     let mutexed_enc = web::Data::new(enc);
 
+    println!("Listen on 0.0.0.0:8080");
     let result = HttpServer::new(move || {
         App::new()
             .app_data(mutexed_enc.clone())