Browse Source

feat(searcher): simple temporary api for embedding

iwanhae 1 year ago
parent
commit
d9b5a035b4
4 changed files with 30 additions and 21 deletions
  1. + 1 - 0
      searcher/Cargo.toml
  2. + 0 - 12
      searcher/src/args.rs
  3. + 1 - 1
      searcher/src/embed/encoder.rs
  4. + 28 - 8
      searcher/src/main.rs

+ 1 - 0
searcher/Cargo.toml

@@ -17,3 +17,4 @@ serde_json = "1.0.99"
 tracing = "0.1.37"
 hf-hub = "0.2.0"
 clap = { version = "4.2.4", features = ["derive"] }
+actix-web = "4"

+ 0 - 12
searcher/src/args.rs

@@ -24,18 +24,6 @@ pub struct Args {
 
     #[arg(long)]
     revision: Option<String>,
-
-    /// When set, compute embeddings for this prompt.
-    #[arg(long)]
-    prompt: Option<String>,
-
-    /// The number of times to run the prompt.
-    #[arg(long, default_value = "1")]
-    n: usize,
-
-    /// L2 normalization for embeddings.
-    #[arg(long, default_value = "true")]
-    normalize_embeddings: bool,
 }
 
 impl Args {

+ 1 - 1
searcher/src/embed/encoder.rs

@@ -33,7 +33,7 @@ impl Encoder {
         assert_eq!(embeddings.shape().dims(), [1, 128, 384]);
         let embeddings = embeddings.sum(1)? / (n_tokens as f64);
 
-        dbg!(start.elapsed());
+        dbg!(prompt, start.elapsed());
         embeddings?.get(0)
     }
 }

+ 28 - 8
searcher/src/main.rs

@@ -1,19 +1,39 @@
 mod args;
 mod embed;
 
-use anyhow::Result;
+use actix_web::{get, web, App, HttpResponse, HttpServer, Responder};
 use args::Args;
 use clap::Parser;
 use embed::encoder;
 
-fn main() -> Result<()> {
+#[get("/")]
+async fn hello() -> impl Responder {
+    HttpResponse::Ok().body(":-)")
+}
+
+#[get("/q/{query}")]
+async fn search(path: web::Path<String>, enc: web::Data<encoder::Encoder>) -> impl Responder {
+    let q = path.to_string();
+    let vec = enc.encode(&q).unwrap().to_string();
+    HttpResponse::Ok().body(vec)
+}
+
+#[actix_web::main]
+async fn main() -> std::io::Result<()> {
     let args = Args::parse();
-    let (model, tokenizer) = args.build_model_and_tokenizer()?;
+    let (model, tokenizer) = args.build_model_and_tokenizer().unwrap();
     let enc = encoder::Encoder::new(model, tokenizer);
+    let mutexed_enc = web::Data::new(enc);
 
-    for _ in 1..10 {
-        let ys = enc.encode("Hello World").unwrap();
-        println!("{:?}", ys.shape())
-    }
-    Ok(())
+    let result = HttpServer::new(move || {
+        App::new()
+            .app_data(mutexed_enc.clone())
+            .service(hello)
+            .service(search)
+    })
+    .bind(("0.0.0.0", 8080))?
+    .run()
+    .await;
+    println!("Server Terminated. Byebye :-)");
+    return result;
 }