Explorar el Código

fix(searcher): no padding & normalize resuting vec

iwanhae hace 1 año
padre
commit
fd920595da
Se han modificado 1 ficheros con 8 adiciones y 4 borrados
  1. 8 4
      searcher/src/embed/encoder.rs

+ 8 - 4
searcher/src/embed/encoder.rs

@@ -19,6 +19,8 @@ impl Encoder {
         let start = Instant::now();
         let tokens = self
             .tokenizer
+            .clone()
+            .with_padding(None)
             .encode(prompt, true)
             .map_err(E::msg)
             .unwrap()
@@ -30,10 +32,12 @@ impl Encoder {
         let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
 
         // mean pooling
-        assert_eq!(embeddings.shape().dims(), [1, 128, 384]);
-        let embeddings = embeddings.sum(1)? / (n_tokens as f64);
-
+        let embeddings = (embeddings.sum(1)? / (n_tokens as f64)).unwrap();
         dbg!(prompt, start.elapsed());
-        embeddings?.get(0)
+        normalize_l2(&embeddings)?.get(0)
     }
 }
+
+pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
+    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+}