@@ -1,5 +1,5 @@
-use candle_core::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, VarBuilder};
+use candle::{DType, Device, Result, Tensor};
+use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 
 pub const DTYPE: DType = DType::F32;
@@ -22,7 +22,7 @@ impl HiddenActLayer {
         Self { act, span }
     }
 
-    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
+    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         let _enter = self.span.enter();
         match self.act {
             // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
@@ -47,7 +47,7 @@ impl Linear {
         Self { weight, bias, span }
     }
 
-    pub fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
+    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
         let _enter = self.span.enter();
         let w = match x.dims() {
             &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
@@ -333,13 +333,13 @@ impl BertSelfAttention {
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
         let attention_probs = {
             let _enter_sm = self.span_softmax.enter();
-            candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)?
+            candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
         };
         let attention_probs = self.dropout.forward(&attention_probs)?;
 
         let context_layer = attention_probs.matmul(&value_layer)?;
         let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
-        let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?;
+        let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
         Ok(context_layer)
     }
 }
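
For reference, a minimal standalone sketch (not part of the patch) of the scaled dot-product attention sequence that the BertSelfAttention hunk touches, written against the renamed `candle` crate path the patch introduces. It assumes `candle-core` is available under the `candle` name (e.g. via a dependency rename) together with `candle-nn`; the `scaled_attention` helper, the tensor shapes, and the `main` harness are illustrative assumptions, not code from this file.

// Illustrative sketch only: mirrors the matmul / scale / softmax / flatten
// sequence shown in the BertSelfAttention hunk, using the renamed `candle` path.
use candle::{Device, Result, Tensor, D};

// `scaled_attention` is a hypothetical helper, not a function from this file.
fn scaled_attention(query: &Tensor, key: &Tensor, value: &Tensor, head_size: usize) -> Result<Tensor> {
    // (batch, heads, seq, head) x (batch, heads, head, seq) -> (batch, heads, seq, seq)
    let scores = query.matmul(&key.t()?)?;
    // Divide by sqrt(head_size), as in the line above the patched softmax call.
    let scores = (scores / (head_size as f64).sqrt())?;
    // Softmax over the last dimension, i.e. candle::D::Minus1 after the rename.
    let probs = candle_nn::ops::softmax(&scores, D::Minus1)?;
    // Weighted sum of the values, then merge the per-head outputs back into one axis.
    let context = probs.matmul(value)?;
    let context = context.transpose(1, 2)?.contiguous()?;
    context.flatten_from(D::Minus2)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, h, s, d): (usize, usize, usize, usize) = (1, 2, 4, 8);
    let q = Tensor::randn(0f32, 1., (b, h, s, d), &dev)?;
    let k = Tensor::randn(0f32, 1., (b, h, s, d), &dev)?;
    let v = Tensor::randn(0f32, 1., (b, h, s, d), &dev)?;
    let out = scaled_attention(&q, &k, &v, d)?;
    // Heads are folded back into the hidden dimension, as flatten_from does above.
    assert_eq!(out.dims(), &[b, s, h * d]);
    Ok(())
}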