feat(search): support intel mkl for x86_64

iwanhae 1 year ago
parent
commit
6f083b3bf9

+ 12 - 1
searcher/.devcontainer/Dockerfile

@@ -1,2 +1,13 @@
 ARG VARIANT="bullseye"
-FROM mcr.microsoft.com/vscode/devcontainers/rust:1-${VARIANT}
+FROM mcr.microsoft.com/vscode/devcontainers/rust:1-${VARIANT}
+
+RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
+  | gpg --dearmor | sudo tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
+  echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | \
+  sudo tee /etc/apt/sources.list.d/oneAPI.list
+
+RUN apt update
+# for builder
+RUN apt install -y intel-oneapi-mkl-devel
+# for runtime environment
+RUN apt install -y libomp-dev
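Note: the two packages serve different stages, as the comments say. intel-oneapi-mkl-devel provides the headers and static libraries the build links against, while libomp-dev supplies the OpenMP runtime the resulting binary loads at startup (the `iomp` in the Cargo feature below refers to MKL's Intel-OpenMP-threaded flavour).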

+ 2 - 1
searcher/.devcontainer/devcontainer.json

@@ -27,7 +27,8 @@
                 "mutantdino.resourcemonitor",
                 "rust-lang.rust-analyzer",
                 "tamasfe.even-better-toml",
-                "serayuzgur.crates"
+                "serayuzgur.crates",
+                "mhutchie.git-graph"
             ]
         }
     },

+ 10 - 2
searcher/Cargo.toml

@@ -6,8 +6,8 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-candle-core = { git = "https://github.com/huggingface/candle.git" }
-candle-nn = { git = "https://github.com/huggingface/candle.git" }
+candle-core = { version = "0.1.0" }
+candle-nn = { version = "0.1.0" }
 tokenizers = { version = "0.13.3", default-features = true }
 anyhow = { version = "1", features = ["backtrace"] }
 serde = { version = "1.0.171", features = ["derive"] }
@@ -16,3 +16,11 @@ tracing = "0.1.37"
 hf-hub = "0.2.0"
 clap = { version = "4.2.4", features = ["derive"] }
 actix-web = "4"
+intel-mkl-src = { version = "0.8.1", features = [
+    "mkl-static-lp64-iomp",
+], optional = true }
+
+
+[features]
+default = []
+mkl = ["dep:intel-mkl-src", "candle-core/mkl", "candle-nn/mkl"]
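With the optional dependency and the `mkl` feature wired up above, enabling MKL follows the usual cfg-gated extern-crate pattern (the same one src/main.rs adopts later in this commit): the otherwise-unused `extern crate` forces the static MKL libraries from intel-mkl-src to be linked, and candle's own `mkl` features switch its kernels over. A minimal sketch:

#[cfg(feature = "mkl")]
extern crate intel_mkl_src; // pulls MKL's static BLAS/LAPACK into the link

fn main() {
    // Nothing else is required here: building with `cargo build -F mkl`
    // activates the dependency and candle's MKL-backed ops, while a plain
    // `cargo build` keeps the previous pure-Rust code path.
}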

+ 14 - 3
searcher/Dockerfile

@@ -1,14 +1,25 @@
-FROM rust:1.71 as builder
+FROM rust:1.71-bullseye as builder
+RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
+    | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
+    echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | \
+    tee /etc/apt/sources.list.d/oneAPI.list
+RUN apt update
+RUN apt install -y intel-oneapi-mkl-devel libomp-dev
+
 WORKDIR /usr/src/kuberian
 ENV HF_HOME=/model
 COPY Cargo.toml .
 COPY ci src/
-RUN cargo run --release
+RUN cargo build -r
 RUN rm -rf src
 COPY . .
-RUN cargo install --path .
+RUN cargo install --path . -F mkl
+RUN kuberian --ci
 
 FROM debian:bullseye-slim
+RUN apt update && \
+    apt install -y libomp-dev && \
+    rm -rf /var/lib/apt/lists/*
 ENV HF_HOME=/model
 COPY --from=builder /model /model
 COPY --from=builder /usr/local/cargo/bin/kuberian /usr/local/bin/kuberian
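Two things worth noting in this build flow: `RUN kuberian --ci` runs the freshly installed binary once with the new --ci flag (added in src/args.rs below), which downloads the BERT weights into HF_HOME=/model and exits, so the runtime stage can simply COPY the cache out of the builder. And the slim runtime image only needs libomp-dev, since with `mkl-static-lp64-iomp` the MKL libraries themselves are linked statically into the binary.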

+ 2 - 188
searcher/ci/main.rs

@@ -1,189 +1,3 @@
-#[cfg(feature = "mkl")]
-extern crate intel_mkl_src;
-mod model;
-
-use anyhow::{anyhow, Error as E, Result};
-use candle_core::Tensor;
-use candle_nn::VarBuilder;
-use clap::Parser;
-use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
-use model::{BertModel, Config, DTYPE};
-use tokenizers::{PaddingParams, Tokenizer};
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
-
-    /// Run offline (you must have the files already cached)
-    #[arg(long)]
-    offline: bool,
-
-    /// Enable tracing (generates a trace-timestamp.json file).
-    #[arg(long)]
-    tracing: bool,
-
-    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
-    #[arg(long)]
-    model_id: Option<String>,
-
-    #[arg(long)]
-    revision: Option<String>,
-
-    /// When set, compute embeddings for this prompt.
-    #[arg(long)]
-    prompt: Option<String>,
-
-    /// The number of times to run the prompt.
-    #[arg(long, default_value = "1")]
-    n: usize,
-
-    /// L2 normalization for embeddings.
-    #[arg(long, default_value = "true")]
-    normalize_embeddings: bool,
-}
-
-impl Args {
-    fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
-        let device = candle_examples::device(self.cpu)?;
-        let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
-        let default_revision = "refs/pr/21".to_string();
-        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
-            (Some(model_id), Some(revision)) => (model_id, revision),
-            (Some(model_id), None) => (model_id, "main".to_string()),
-            (None, Some(revision)) => (default_model, revision),
-            (None, None) => (default_model, default_revision),
-        };
-
-        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
-            let cache = Cache::default();
-            (
-                cache
-                    .get(&repo, "config.json")
-                    .ok_or(anyhow!("Missing config file in cache"))?,
-                cache
-                    .get(&repo, "tokenizer.json")
-                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
-                cache
-                    .get(&repo, "model.safetensors")
-                    .ok_or(anyhow!("Missing weights file in cache"))?,
-            )
-        } else {
-            let api = Api::new()?;
-            let api = api.repo(repo);
-            (
-                api.get("config.json")?,
-                api.get("tokenizer.json")?,
-                api.get("model.safetensors")?,
-            )
-        };
-        let config = std::fs::read_to_string(config_filename)?;
-        let config: Config = serde_json::from_str(&config)?;
-        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-
-        let weights = unsafe { candle_core::safetensors::MmapedFile::new(weights_filename)? };
-        let weights = weights.deserialize()?;
-        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
-        let model = BertModel::load(vb, &config)?;
-        Ok((model, tokenizer))
-    }
-}
-
-fn main() -> Result<()> {
-    let args = Args::parse();
-    let start = std::time::Instant::now();
-
-    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
-    let device = &model.device;
-
-    if let Some(prompt) = args.prompt {
-        let tokenizer = tokenizer.with_padding(None).with_truncation(None);
-        let tokens = tokenizer
-            .encode(prompt, true)
-            .map_err(E::msg)?
-            .get_ids()
-            .to_vec();
-        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        println!("Loaded and encoded {:?}", start.elapsed());
-        for idx in 0..args.n {
-            let start = std::time::Instant::now();
-            let ys = model.forward(&token_ids, &token_type_ids)?;
-            if idx == 0 {
-                println!("{ys}");
-            }
-            println!("Took {:?}", start.elapsed());
-        }
-    } else {
-        let sentences = [
-            "The cat sits outside",
-            "A man is playing guitar",
-            "I love pasta",
-            "The new movie is awesome",
-            "The cat plays in the garden",
-            "A woman watches TV",
-            "The new movie is so great",
-            "Do you like pizza?",
-        ];
-        let n_sentences = sentences.len();
-        if let Some(pp) = tokenizer.get_padding_mut() {
-            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
-        } else {
-            let pp = PaddingParams {
-                strategy: tokenizers::PaddingStrategy::BatchLongest,
-                ..Default::default()
-            };
-            tokenizer.with_padding(Some(pp));
-        }
-        let tokens = tokenizer
-            .encode_batch(sentences.to_vec(), true)
-            .map_err(E::msg)?;
-        let token_ids = tokens
-            .iter()
-            .map(|tokens| {
-                let tokens = tokens.get_ids().to_vec();
-                Ok(Tensor::new(tokens.as_slice(), device)?)
-            })
-            .collect::<Result<Vec<_>>>()?;
-
-        let token_ids = Tensor::stack(&token_ids, 0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        println!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = model.forward(&token_ids, &token_type_ids)?;
-        println!("generated embeddings {:?}", embeddings.shape());
-        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
-        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
-        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
-        let embeddings = if args.normalize_embeddings {
-            normalize_l2(&embeddings)?
-        } else {
-            embeddings
-        };
-        println!("pooled embeddings {:?}", embeddings.shape());
-
-        let mut similarities = vec![];
-        for i in 0..n_sentences {
-            let e_i = embeddings.get(i)?;
-            for j in (i + 1)..n_sentences {
-                let e_j = embeddings.get(j)?;
-                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
-                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
-                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
-                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
-                similarities.push((cosine_similarity, i, j))
-            }
-        }
-        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
-        for &(score, i, j) in similarities[..5].iter() {
-            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
-        }
-    }
-    Ok(())
-}
-
-pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
-    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+fn main() {
+    println!("Hello World")
 }
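Gutting ci/main.rs down to this stub is a Docker layer-caching trick: the builder stage copies only Cargo.toml plus this stub (`COPY ci src/`), runs `cargo build -r` to compile and cache every dependency, and only then copies the real sources. As long as Cargo.toml is unchanged, rebuilds skip dependency compilation entirely.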

+ 0 - 568
searcher/ci/model.rs

@@ -1,568 +0,0 @@
-use candle_core::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, VarBuilder};
-use serde::Deserialize;
-
-pub const DTYPE: DType = DType::F32;
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
-#[serde(rename_all = "lowercase")]
-enum HiddenAct {
-    Gelu,
-    Relu,
-}
-
-struct HiddenActLayer {
-    act: HiddenAct,
-    span: tracing::Span,
-}
-
-impl HiddenActLayer {
-    fn new(act: HiddenAct) -> Self {
-        let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
-        Self { act, span }
-    }
-
-    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
-        let _enter = self.span.enter();
-        match self.act {
-            // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
-            // small numerical difference.
-            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
-            HiddenAct::Gelu => xs.gelu(),
-            HiddenAct::Relu => xs.relu(),
-        }
-    }
-}
-
-#[derive(Debug)]
-pub struct Linear {
-    weight: Tensor,
-    bias: Option<Tensor>,
-    span: tracing::Span,
-}
-
-impl Linear {
-    pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
-        let span = tracing::span!(tracing::Level::TRACE, "linear");
-        Self { weight, bias, span }
-    }
-
-    pub fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
-        let _enter = self.span.enter();
-        let w = match x.dims() {
-            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
-            _ => self.weight.t()?,
-        };
-        let x = x.matmul(&w)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
-}
-
-#[derive(Debug)]
-pub struct LayerNorm {
-    weight: Tensor,
-    bias: Tensor,
-    eps: f64,
-    span: tracing::Span,
-}
-
-impl LayerNorm {
-    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
-        Self {
-            weight,
-            bias,
-            eps,
-            span,
-        }
-    }
-
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let x_dtype = x.dtype();
-        let internal_dtype = match x_dtype {
-            DType::F16 | DType::BF16 => DType::F32,
-            d => d,
-        };
-        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
-        let x = x.to_dtype(internal_dtype)?;
-        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .to_dtype(x_dtype)?
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
-    }
-}
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
-#[serde(rename_all = "lowercase")]
-enum PositionEmbeddingType {
-    #[default]
-    Absolute,
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
-#[derive(Debug, Clone, PartialEq, Deserialize)]
-pub struct Config {
-    vocab_size: usize,
-    hidden_size: usize,
-    num_hidden_layers: usize,
-    num_attention_heads: usize,
-    intermediate_size: usize,
-    hidden_act: HiddenAct,
-    hidden_dropout_prob: f64,
-    max_position_embeddings: usize,
-    type_vocab_size: usize,
-    initializer_range: f64,
-    layer_norm_eps: f64,
-    pad_token_id: usize,
-    #[serde(default)]
-    position_embedding_type: PositionEmbeddingType,
-    #[serde(default)]
-    use_cache: bool,
-    classifier_dropout: Option<f64>,
-    model_type: Option<String>,
-}
-
-impl Default for Config {
-    fn default() -> Self {
-        Self {
-            vocab_size: 30522,
-            hidden_size: 768,
-            num_hidden_layers: 12,
-            num_attention_heads: 12,
-            intermediate_size: 3072,
-            hidden_act: HiddenAct::Gelu,
-            hidden_dropout_prob: 0.1,
-            max_position_embeddings: 512,
-            type_vocab_size: 2,
-            initializer_range: 0.02,
-            layer_norm_eps: 1e-12,
-            pad_token_id: 0,
-            position_embedding_type: PositionEmbeddingType::Absolute,
-            use_cache: true,
-            classifier_dropout: None,
-            model_type: Some("bert".to_string()),
-        }
-    }
-}
-
-impl Config {
-    fn _all_mini_lm_l6_v2() -> Self {
-        // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
-        Self {
-            vocab_size: 30522,
-            hidden_size: 384,
-            num_hidden_layers: 6,
-            num_attention_heads: 12,
-            intermediate_size: 1536,
-            hidden_act: HiddenAct::Gelu,
-            hidden_dropout_prob: 0.1,
-            max_position_embeddings: 512,
-            type_vocab_size: 2,
-            initializer_range: 0.02,
-            layer_norm_eps: 1e-12,
-            pad_token_id: 0,
-            position_embedding_type: PositionEmbeddingType::Absolute,
-            use_cache: true,
-            classifier_dropout: None,
-            model_type: Some("bert".to_string()),
-        }
-    }
-}
-
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
-
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    let bias = vb.get(size2, "bias")?;
-    Ok(Linear::new(weight, Some(bias)))
-}
-
-struct Dropout {
-    #[allow(dead_code)]
-    pr: f64,
-}
-
-impl Dropout {
-    fn new(pr: f64) -> Self {
-        Self { pr }
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        // TODO
-        Ok(x.clone())
-    }
-}
-
-fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
-    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
-        (Ok(weight), Ok(bias)) => (weight, bias),
-        (Err(err), _) | (_, Err(err)) => {
-            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
-                (weight, bias)
-            } else {
-                return Err(err);
-            }
-        }
-    };
-    Ok(LayerNorm::new(weight, bias, eps))
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
-struct BertEmbeddings {
-    word_embeddings: Embedding,
-    position_embeddings: Option<Embedding>,
-    token_type_embeddings: Embedding,
-    layer_norm: LayerNorm,
-    dropout: Dropout,
-    span: tracing::Span,
-}
-
-impl BertEmbeddings {
-    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
-        let word_embeddings = embedding(
-            config.vocab_size,
-            config.hidden_size,
-            vb.pp("word_embeddings"),
-        )?;
-        let position_embeddings = embedding(
-            config.max_position_embeddings,
-            config.hidden_size,
-            vb.pp("position_embeddings"),
-        )?;
-        let token_type_embeddings = embedding(
-            config.type_vocab_size,
-            config.hidden_size,
-            vb.pp("token_type_embeddings"),
-        )?;
-        let layer_norm = layer_norm(
-            config.hidden_size,
-            config.layer_norm_eps,
-            vb.pp("LayerNorm"),
-        )?;
-        Ok(Self {
-            word_embeddings,
-            position_embeddings: Some(position_embeddings),
-            token_type_embeddings,
-            layer_norm,
-            dropout: Dropout::new(config.hidden_dropout_prob),
-            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
-        })
-    }
-
-    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let (_bsize, seq_len) = input_ids.dims2()?;
-        let input_embeddings = self.word_embeddings.forward(input_ids)?;
-        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
-        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
-        if let Some(position_embeddings) = &self.position_embeddings {
-            // TODO: Proper absolute positions?
-            let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
-            let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
-            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
-        }
-        let embeddings = self.layer_norm.forward(&embeddings)?;
-        let embeddings = self.dropout.forward(&embeddings)?;
-        Ok(embeddings)
-    }
-}
-
-struct BertSelfAttention {
-    query: Linear,
-    key: Linear,
-    value: Linear,
-    dropout: Dropout,
-    num_attention_heads: usize,
-    attention_head_size: usize,
-    span: tracing::Span,
-    span_softmax: tracing::Span,
-}
-
-impl BertSelfAttention {
-    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
-        let attention_head_size = config.hidden_size / config.num_attention_heads;
-        let all_head_size = config.num_attention_heads * attention_head_size;
-        let dropout = Dropout::new(config.hidden_dropout_prob);
-        let hidden_size = config.hidden_size;
-        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
-        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
-        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
-        Ok(Self {
-            query,
-            key,
-            value,
-            dropout,
-            num_attention_heads: config.num_attention_heads,
-            attention_head_size,
-            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
-            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
-        })
-    }
-
-    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
-        let mut new_x_shape = xs.dims().to_vec();
-        new_x_shape.pop();
-        new_x_shape.push(self.num_attention_heads);
-        new_x_shape.push(self.attention_head_size);
-        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
-        xs.contiguous()
-    }
-
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let query_layer = self.query.forward(hidden_states)?;
-        let key_layer = self.key.forward(hidden_states)?;
-        let value_layer = self.value.forward(hidden_states)?;
-
-        let query_layer = self.transpose_for_scores(&query_layer)?;
-        let key_layer = self.transpose_for_scores(&key_layer)?;
-        let value_layer = self.transpose_for_scores(&value_layer)?;
-
-        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
-        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
-        let attention_probs = {
-            let _enter_sm = self.span_softmax.enter();
-            candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)?
-        };
-        let attention_probs = self.dropout.forward(&attention_probs)?;
-
-        let context_layer = attention_probs.matmul(&value_layer)?;
-        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
-        let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?;
-        Ok(context_layer)
-    }
-}
-
-struct BertSelfOutput {
-    dense: Linear,
-    layer_norm: LayerNorm,
-    dropout: Dropout,
-    span: tracing::Span,
-}
-
-impl BertSelfOutput {
-    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
-        let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
-        let layer_norm = layer_norm(
-            config.hidden_size,
-            config.layer_norm_eps,
-            vb.pp("LayerNorm"),
-        )?;
-        let dropout = Dropout::new(config.hidden_dropout_prob);
-        Ok(Self {
-            dense,
-            layer_norm,
-            dropout,
-            span: tracing::span!(tracing::Level::TRACE, "self-out"),
-        })
-    }
-
-    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let hidden_states = self.dense.forward(hidden_states)?;
-        let hidden_states = self.dropout.forward(&hidden_states)?;
-        self.layer_norm.forward(&(hidden_states + input_tensor)?)
-    }
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
-struct BertAttention {
-    self_attention: BertSelfAttention,
-    self_output: BertSelfOutput,
-    span: tracing::Span,
-}
-
-impl BertAttention {
-    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
-        let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
-        let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
-        Ok(Self {
-            self_attention,
-            self_output,
-            span: tracing::span!(tracing::Level::TRACE, "attn"),
-        })
-    }
-
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let self_outputs = self.self_attention.forward(hidden_states)?;
-        let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
-        Ok(attention_output)
-    }
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
-struct BertIntermediate {
-    dense: Linear,
-    intermediate_act: HiddenActLayer,
-    span: tracing::Span,
-}
-
-impl BertIntermediate {
-    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
-        let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
-        Ok(Self {
-            dense,
-            intermediate_act: HiddenActLayer::new(config.hidden_act),
-            span: tracing::span!(tracing::Level::TRACE, "inter"),
-        })
-    }
-
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let hidden_states = self.dense.forward(hidden_states)?;
-        let ys = self.intermediate_act.forward(&hidden_states)?;
-        Ok(ys)
-    }
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
-struct BertOutput {
-    dense: Linear,
-    layer_norm: LayerNorm,
-    dropout: Dropout,
-    span: tracing::Span,
-}
-
-impl BertOutput {
-    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
-        let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
-        let layer_norm = layer_norm(
-            config.hidden_size,
-            config.layer_norm_eps,
-            vb.pp("LayerNorm"),
-        )?;
-        let dropout = Dropout::new(config.hidden_dropout_prob);
-        Ok(Self {
-            dense,
-            layer_norm,
-            dropout,
-            span: tracing::span!(tracing::Level::TRACE, "out"),
-        })
-    }
-
-    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let hidden_states = self.dense.forward(hidden_states)?;
-        let hidden_states = self.dropout.forward(&hidden_states)?;
-        self.layer_norm.forward(&(hidden_states + input_tensor)?)
-    }
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
-struct BertLayer {
-    attention: BertAttention,
-    intermediate: BertIntermediate,
-    output: BertOutput,
-    span: tracing::Span,
-}
-
-impl BertLayer {
-    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
-        let attention = BertAttention::load(vb.pp("attention"), config)?;
-        let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
-        let output = BertOutput::load(vb.pp("output"), config)?;
-        Ok(Self {
-            attention,
-            intermediate,
-            output,
-            span: tracing::span!(tracing::Level::TRACE, "layer"),
-        })
-    }
-
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let attention_output = self.attention.forward(hidden_states)?;
-        // TODO: Support cross-attention?
-        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
-        // TODO: Support something similar to `apply_chunking_to_forward`?
-        let intermediate_output = self.intermediate.forward(&attention_output)?;
-        let layer_output = self
-            .output
-            .forward(&intermediate_output, &attention_output)?;
-        Ok(layer_output)
-    }
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
-struct BertEncoder {
-    layers: Vec<BertLayer>,
-    span: tracing::Span,
-}
-
-impl BertEncoder {
-    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
-        let layers = (0..config.num_hidden_layers)
-            .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
-            .collect::<Result<Vec<_>>>()?;
-        let span = tracing::span!(tracing::Level::TRACE, "encoder");
-        Ok(BertEncoder { layers, span })
-    }
-
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let mut hidden_states = hidden_states.clone();
-        // Use a loop rather than a fold as it's easier to modify when adding debug/...
-        for layer in self.layers.iter() {
-            hidden_states = layer.forward(&hidden_states)?
-        }
-        Ok(hidden_states)
-    }
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
-pub struct BertModel {
-    embeddings: BertEmbeddings,
-    encoder: BertEncoder,
-    pub device: Device,
-    span: tracing::Span,
-}
-
-impl BertModel {
-    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
-        let (embeddings, encoder) = match (
-            BertEmbeddings::load(vb.pp("embeddings"), config),
-            BertEncoder::load(vb.pp("encoder"), config),
-        ) {
-            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
-            (Err(err), _) | (_, Err(err)) => {
-                if let Some(model_type) = &config.model_type {
-                    if let (Ok(embeddings), Ok(encoder)) = (
-                        BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
-                        BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
-                    ) {
-                        (embeddings, encoder)
-                    } else {
-                        return Err(err);
-                    }
-                } else {
-                    return Err(err);
-                }
-            }
-        };
-        Ok(Self {
-            embeddings,
-            encoder,
-            device: vb.device().clone(),
-            span: tracing::span!(tracing::Level::TRACE, "model"),
-        })
-    }
-
-    pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
-        let sequence_output = self.encoder.forward(&embedding_output)?;
-        Ok(sequence_output)
-    }
-}
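(The deleted file was a vendored copy of candle's BERT example model. With ci/main.rs reduced to the stub above it no longer needs to compile; the model loading used by the actual server presumably lives under src/, as the src/args.rs change below suggests.)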

+ 10 - 0
searcher/src/args.rs

@@ -11,6 +11,10 @@ use tokenizers::Tokenizer;
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 pub struct Args {
+    /// terminate immediately (just for downloading the BERT model)
+    #[arg(long)]
+    ci: bool,
+
     /// Run offline (you must have the files already cached)
     #[arg(long)]
     offline: bool,
@@ -24,6 +28,12 @@ pub struct Args {
 }
 
 impl Args {
+    pub fn terminate_if_ci(&self) {
+        if self.ci {
+            println!("terminating ci mode");
+            std::process::exit(0)
+        }
+    }
     pub fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
         let device = Device::Cpu;
         let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
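For context, a condensed sketch of the intended call order (matching the src/main.rs change below, with `Args` assumed in scope from this module): the model and tokenizer are built first, so a `--ci` run still populates the Hugging Face cache before exiting.

use clap::Parser; // assumed import, as in src/args.rs

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    // Downloads config/tokenizer/weights into $HF_HOME on first run.
    let (model, tokenizer) = args.build_model_and_tokenizer()?;
    // With --ci: print a note and exit(0) before the server ever starts.
    args.terminate_if_ci();
    // ... otherwise hand `model`/`tokenizer` to the actix-web server ...
    Ok(())
}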

+ 5 - 1
searcher/src/main.rs

@@ -1,3 +1,5 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
 mod args;
 mod embed;
 
@@ -25,6 +27,8 @@ async fn main() -> std::io::Result<()> {
     let enc = encoder::Encoder::new(model, tokenizer);
     let mutexed_enc = web::Data::new(enc);
 
+    args.terminate_if_ci();
+
     println!("Listen on 0.0.0.0:8080");
     let result = HttpServer::new(move || {
         App::new()
@@ -36,5 +40,5 @@ async fn main() -> std::io::Result<()> {
     .run()
     .await;
     println!("Server Terminated. Byebye :-)");
-    return result;
+    result
 }