@@ -1,5 +1,5 @@
-use candle_core::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, VarBuilder};
+use candle::{DType, Device, Result, Tensor};
+use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 
 pub const DTYPE: DType = DType::F32;
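Note: besides retargeting the crate rename (`candle_core` → `candle`), this hunk pulls `Module` into scope, presumably because the `forward` calls on layers like `Embedding` resolve through candle_nn's `Module` trait rather than an inherent method. A minimal sketch of why the import matters; the `embed` helper is hypothetical, for illustration only:

use candle::{Result, Tensor};
use candle_nn::{Embedding, Module};

fn embed(emb: &Embedding, token_ids: &Tensor) -> Result<Tensor> {
    // `forward` is a `Module` trait method, so this call only
    // compiles with `Module` imported.
    emb.forward(token_ids)
}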
													
												
											
										
											
												
													
@@ -22,7 +22,7 @@ impl HiddenActLayer {
         Self { act, span }
     }
 
-    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
+    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         let _enter = self.span.enter();
         match self.act {
             // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
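Note on the TODO kept as context above: candle's built-in `Tensor::gelu` is the tanh approximation that Hugging Face configs call `gelu_new`, while all-MiniLM-L6-v2 was trained with the exact erf-based `gelu`, hence the small numerical differences the comment alludes to. A hedged sketch of what the `Gelu` arm of this dispatch boils down to, assuming that mapping:

use candle::{Result, Tensor};

fn gelu_new(xs: &Tensor) -> Result<Tensor> {
    // candle's `gelu` op: tanh-approximate gelu ("gelu_new" in HF terms).
    xs.gelu()
}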
													
												
											
										
											
												
													
@@ -47,7 +47,7 @@ impl Linear {
         Self { weight, bias, span }
     }
 
-    pub fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
+    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
         let _enter = self.span.enter();
         let w = match x.dims() {
             &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
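Note: only the error-type path changes in `Linear::forward`; the shape handling kept as context above broadcasts the `(out, in)` weight across the batch before a batched matmul. A standalone sketch of that logic, with a hypothetical `linear_forward` helper:

use candle::{Result, Tensor};

fn linear_forward(x: &Tensor, weight: &Tensor) -> Result<Tensor> {
    // weight: (out_dim, in_dim). For x of shape (bsize, seq, in_dim), broadcast
    // the weight to (bsize, out_dim, in_dim) and transpose its last two dims so
    // the batched matmul yields (bsize, seq, out_dim).
    let w = match x.dims() {
        &[bsize, _, _] => weight.broadcast_left(bsize)?.t()?,
        _ => weight.t()?,
    };
    x.matmul(&w)
}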
													
												
											
										
											
												
													
@@ -333,13 +333,13 @@ impl BertSelfAttention {
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
         let attention_probs = {
             let _enter_sm = self.span_softmax.enter();
-            candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)?
+            candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
         };
         let attention_probs = self.dropout.forward(&attention_probs)?;
 
         let context_layer = attention_probs.matmul(&value_layer)?;
         let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
-        let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?;
+        let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
         Ok(context_layer)
     }
 }
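Note: both edits in this last hunk only move `D::Minus1`/`D::Minus2` to the renamed `candle` crate; the attention tail itself is unchanged. A minimal sketch of that tail, with a hypothetical `attention_tail` helper and assumed shapes:

use candle::{D, Result, Tensor};

fn attention_tail(scores: &Tensor, value: &Tensor) -> Result<Tensor> {
    // scores: (batch, heads, seq, seq); value: (batch, heads, seq, head_dim).
    let probs = candle_nn::ops::softmax(scores, D::Minus1)?;
    let ctx = probs.matmul(value)?;               // (batch, heads, seq, head_dim)
    let ctx = ctx.transpose(1, 2)?.contiguous()?; // (batch, seq, heads, head_dim)
    ctx.flatten_from(D::Minus2)                   // (batch, seq, heads * head_dim)
}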