
fix(searcher): use candle 0.1.2

iwanhae 1 year ago
parent commit 196a6df451

+ 3 - 3
searcher/Cargo.toml

@@ -6,8 +6,8 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-candle-core = { version = "0.1.0" }
-candle-nn = { version = "0.1.0" }
+candle = { package = "candle-core", version = "0.1.2" }
+candle-nn = { version = "0.1.2" }
 tokenizers = { version = "0.13.3", default-features = true }
 anyhow = { version = "1", features = ["backtrace"] }
 serde = { version = "1.0.171", features = ["derive"] }
@@ -23,4 +23,4 @@ intel-mkl-src = { version = "0.8.1", features = [
 
 [features]
 default = []
-mkl = ["dep:intel-mkl-src", "candle-core/mkl", "candle-nn/mkl"]
+mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
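
Note: the Cargo.toml change relies on Cargo's dependency-renaming feature. The entry candle = { package = "candle-core", version = "0.1.2" } still pulls the candle-core package from crates.io but exposes it in this crate under the local name candle, which is why every use candle_core::... in the source files below becomes use candle::.... A minimal sketch of a call site with the renamed crate (the tensor shape is illustrative only):

    use candle::{DType, Device, Tensor};

    fn zeros() -> candle::Result<Tensor> {
        // `candle` here is the renamed `candle-core` package
        Tensor::zeros((2, 3), DType::F32, &Device::Cpu)
    }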

+ 2 - 2
searcher/src/args.rs

@@ -1,7 +1,7 @@
 use crate::embed;
 
 use anyhow::{anyhow, Error as E, Result};
-use candle_core::Device;
+use candle::Device;
 use candle_nn::VarBuilder;
 use clap::Parser;
 use embed::model::{BertModel, Config, DTYPE};
@@ -74,7 +74,7 @@ impl Args {
         let config: Config = serde_json::from_str(&config)?;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
-        let weights = unsafe { candle_core::safetensors::MmapedFile::new(weights_filename)? };
+        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
         let weights = weights.deserialize()?;
         let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
         let model = BertModel::load(vb, &config)?;
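
Only the import path changes in args.rs; the loading flow itself is untouched. For context, MmapedFile memory-maps the safetensors file instead of reading it into a buffer, and the constructor is unsafe because the mapping is only sound while the file is not modified on disk. A condensed sketch of the same flow under the new crate name (the file path is a placeholder):

    // Placeholder path; the real code resolves it from CLI arguments.
    let weights = unsafe { candle::safetensors::MmapedFile::new("model.safetensors")? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
    let model = BertModel::load(vb, &config)?;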

+ 1 - 1
searcher/src/embed/encoder.rs

@@ -1,7 +1,7 @@
 use super::model;
 
 use anyhow::Error as E;
-use candle_core::{Result, Tensor};
+use candle::{Result, Tensor};
 use std::time::Instant;
 use tokenizers::Tokenizer;
 

+ 6 - 6
searcher/src/embed/model.rs

@@ -1,5 +1,5 @@
-use candle_core::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, VarBuilder};
+use candle::{DType, Device, Result, Tensor};
+use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 
 pub const DTYPE: DType = DType::F32;
@@ -22,7 +22,7 @@ impl HiddenActLayer {
         Self { act, span }
     }
 
-    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
+    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         let _enter = self.span.enter();
         match self.act {
             // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
@@ -47,7 +47,7 @@ impl Linear {
         Self { weight, bias, span }
     }
 
-    pub fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
+    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
         let _enter = self.span.enter();
         let w = match x.dims() {
             &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
@@ -333,13 +333,13 @@ impl BertSelfAttention {
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
         let attention_probs = {
             let _enter_sm = self.span_softmax.enter();
-            candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)?
+            candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
         };
         let attention_probs = self.dropout.forward(&attention_probs)?;
 
         let context_layer = attention_probs.matmul(&value_layer)?;
         let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
-        let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?;
+        let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
         Ok(context_layer)
     }
 }
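
Besides the path renames, model.rs now also imports Module from candle_nn. A plausible reading, assuming candle-nn 0.1.2's trait-based layer API, is that forward on built-in layers such as Embedding is provided by the Module trait, so the trait must be in scope wherever those layers are called. A sketch:

    use candle_nn::{Embedding, Module};

    // Sketch only: `forward` on Embedding resolves through the Module trait,
    // which is why the trait import was added in this commit.
    fn embed_ids(emb: &Embedding, ids: &candle::Tensor) -> candle::Result<candle::Tensor> {
        emb.forward(ids)
    }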