|
@@ -19,6 +19,8 @@ impl Encoder {
|
|
|
let start = Instant::now();
|
|
|
let tokens = self
|
|
|
.tokenizer
|
|
|
+ .clone()
|
|
|
+ .with_padding(None)
|
|
|
.encode(prompt, true)
|
|
|
.map_err(E::msg)
|
|
|
.unwrap()
|
|
@@ -30,10 +32,12 @@ impl Encoder {
|
|
|
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
|
|
|
|
|
// mean pooling
|
|
|
- assert_eq!(embeddings.shape().dims(), [1, 128, 384]);
|
|
|
- let embeddings = embeddings.sum(1)? / (n_tokens as f64);
|
|
|
-
|
|
|
+ let embeddings = (embeddings.sum(1)? / (n_tokens as f64)).unwrap();
|
|
|
dbg!(prompt, start.elapsed());
|
|
|
- embeddings?.get(0)
|
|
|
+ normalize_l2(&embeddings)?.get(0)
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
|
|
|
+ Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
|
|
|
+}
|