4 Commits a57783f2b8 ... 4bc303aa09

Author SHA1 Message Date
  iwanhae 4bc303aa09 feat(searcher): a little query optimization 1 year ago
  iwanhae c2368f756c feat(searcher): basic logger 1 year ago
  iwanhae 54c24fd4d9 feat(searcher): search api 1 year ago
  iwanhae a57783f2b8 tmp 1 year ago

+ 2 - 0
searcher/Cargo.toml

@@ -25,6 +25,8 @@ rusqlite = { version = "0.29.0", features = ["bundled", "trace", "array"] }
 usearch = "1.1.1"
 actix = "0.13.0"
 cxx = "1.0.106"
+env_logger = "0.10.0"
+log = "0.4.20"
 
 
 [features]

+ 1 - 1
searcher/src/args.rs

@@ -30,7 +30,7 @@ pub struct Args {
 impl Args {
     pub fn terminate_if_ci(&self) {
         if self.ci {
-            println!("terminating ci mode");
+            info!("terminating ci mode");
             std::process::exit(0)
         }
     }

+ 1 - 3
searcher/src/database/mod.rs

@@ -6,7 +6,5 @@ use r2d2_sqlite::SqliteConnectionManager;
 pub fn new_connection_pool() -> Pool<SqliteConnectionManager> {
     let database_url = env::var("DATABASE_URL").unwrap_or_else(|_| String::from("./kuberian.db"));
     let manager = SqliteConnectionManager::file(database_url);
-    let pool = Pool::new(manager).expect("fail to generate db connection pool");
-
-    pool
+    Pool::new(manager).expect("fail to generate db connection pool")
 }

+ 6 - 5
searcher/src/embed/encoder.rs

@@ -1,7 +1,7 @@
 use super::model;
 
 use anyhow::Error as E;
-use candle::{Result, Tensor};
+use candle::Tensor;
 use std::time::Instant;
 use tokenizers::Tokenizer;
 
@@ -36,8 +36,8 @@ impl Encoder {
 
         // mean pooling
         let embeddings = (embeddings.sum(1).unwrap() / (n_tokens as f64)).unwrap();
-        dbg!(prompt, start.elapsed());
-        let tensor = normalize_l2(&embeddings).unwrap().get(0).unwrap();
+        debug!("{:?} {}", start.elapsed(), prompt);
+        let tensor = normalize_l2(&embeddings).get(0).unwrap();
 
         let mut result: Vec<f32> = Vec::new();
         for i in 0..384 {
@@ -47,6 +47,7 @@ impl Encoder {
     }
 }
 
-pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
-    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+pub fn normalize_l2(v: &Tensor) -> Tensor {
+    v.broadcast_div(&v.sqr().unwrap().sum_keepdim(1).unwrap().sqrt().unwrap())
+        .unwrap()
 }

+ 15 - 3
searcher/src/main.rs

@@ -1,18 +1,23 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
+
+#[macro_use]
+extern crate log;
+
 mod args;
 mod database;
 mod embed;
 mod server;
 mod vector_db;
 
-use actix_web::{web, App, HttpServer};
+use actix_web::{middleware, web, App, HttpServer};
 use args::Args;
 use clap::Parser;
 use embed::encoder;
 
 #[actix_web::main]
 async fn main() -> std::io::Result<()> {
+    env_logger::init_from_env(env_logger::Env::new().default_filter_or("debug"));
     let args = Args::parse();
 
     // [BEGIN] INIT
@@ -21,30 +26,37 @@ async fn main() -> std::io::Result<()> {
     let (model, tokenizer) = args.build_model_and_tokenizer().unwrap();
     let enc = encoder::Encoder::new(model, tokenizer);
     let app_data_encoder = web::Data::new(enc);
+    info!("MODEL loaded");
 
     // DATABSE
     let database = database::new_connection_pool();
     let app_data_database = web::Data::new(database);
+    info!("DATABASE loaded");
 
     // VECTOR DATABASE
     let vector = vector_db::start();
     let app_data_vector = web::Data::new(vector);
+    info!("VECTOR DATABASE loaded");
 
     // [END] INIT
     args.terminate_if_ci();
 
-    println!("Listen on 0.0.0.0:8080");
+    info!("Listen on 0.0.0.0:8080");
+
     let result = HttpServer::new(move || {
         App::new()
+            .wrap(middleware::Logger::default())
+            .wrap(middleware::Compress::default())
             .app_data(app_data_encoder.clone())
             .app_data(app_data_database.clone())
             .app_data(app_data_vector.clone())
+            .service(server::handler::healthz)
             .service(server::handler::hello)
             .service(server::handler::search)
     })
     .bind(("0.0.0.0", 8080))?
     .run()
     .await;
-    println!("Server Terminated. Byebye :-)");
+    info!("Server Terminated. Byebye :-)");
     result
 }

+ 50 - 16
searcher/src/server/handler.rs

@@ -1,10 +1,11 @@
+use std::sync::atomic::{AtomicU64, Ordering};
+
 use super::types;
 use actix::Addr;
 use actix_web::{get, web::Data, web::Path, HttpResponse, Responder};
 use r2d2::Pool;
 use r2d2_sqlite::SqliteConnectionManager;
 use rusqlite::params_from_iter;
-use serde_json::json;
 
 use crate::{
     embed::encoder::Encoder,
@@ -12,6 +13,11 @@ use crate::{
     vector_db::{self, VectorDB},
 };
 
+#[get("/healthz")]
+async fn healthz() -> impl Responder {
+    HttpResponse::Ok().body("ok")
+}
+
 #[get("/")]
 async fn hello(pool: Data<Pool<SqliteConnectionManager>>) -> impl Responder {
     let conn = pool.get().unwrap();
@@ -30,7 +36,10 @@ async fn hello(pool: Data<Pool<SqliteConnectionManager>>) -> impl Responder {
 
     let mut samples: Vec<String> = vec![];
     for row in conn
-        .prepare_cached("SELECT summary FROM function_analyses ORDER BY RANDOM() LIMIT 10;")
+        .prepare_cached(
+            r"SELECT summary FROM function_analyses WHERE id 
+            IN (SELECT id FROM function_analyses ORDER BY RANDOM() LIMIT 10);",
+        )
         .unwrap()
         .query_map([], |row| row.get::<usize, String>(0))
         .unwrap()
@@ -42,7 +51,7 @@ async fn hello(pool: Data<Pool<SqliteConnectionManager>>) -> impl Responder {
     HttpResponse::Ok().json(types::ResponseHello {
         total: count_functions,
         analyzed: count_analyses,
-        samples: samples,
+        samples,
     })
 }
 
@@ -58,20 +67,45 @@ async fn search(
     let q = path.to_string();
     let embeddings = enc.encode(&q);
     let doc_ids = vecdb.send(vector_db::Query(embeddings)).await.unwrap();
-    let mut docs: Vec<String> = Vec::new();
-    for doc in conn
-        .prepare_cached(&format!(
-            "SELECT id, name, signature FROM functions WHERE id IN ({});",
-            repeat_vars(doc_ids.len())
-        ))
-        .unwrap()
-        .query_map(params_from_iter(doc_ids.iter()), |row| {
-            row.get::<usize, String>(2)
+    let mut docs: Vec<types::DocumentSummary> = Vec::new();
+    conn.prepare_cached(&format!(
+        r#"
+            SELECT src.id, name, signature, file, line_start, line_end, tgt.summary 
+            FROM functions as src 
+                LEFT JOIN function_analyses as tgt 
+                ON src.id = tgt.function_id 
+            WHERE src.id IN ({});
+            "#,
+        repeat_vars(doc_ids.len())
+    ))
+    .expect("fail to query on database")
+    .query_map(params_from_iter(doc_ids.iter().map(|x| x.0)), |row| {
+        Ok(types::DocumentSummary {
+            id: row.get(0).unwrap(),
+            name: row.get(1).unwrap(),
+            signature: row.get(2).unwrap(),
+            file: row.get(3).unwrap(),
+            line: types::LineInfo {
+                start: row.get(4).unwrap(),
+                end: row.get(5).unwrap(),
+            },
+            summary: row.get(6).unwrap(),
         })
-        .unwrap()
-    {
+    })
+    .expect("unexpected result from database")
+    .for_each(|doc| {
         docs.push(doc.unwrap());
-    }
+    });
 
-    HttpResponse::Ok().json(json!({"result": doc_ids, "test": docs}))
+    HttpResponse::Ok().json(types::ResponseSearch {
+        query: q,
+        results: doc_ids
+            .iter()
+            .map(|x| types::SearchResult {
+                id: x.0,
+                score: x.1,
+            })
+            .collect(),
+        docs,
+    })
 }

+ 29 - 0
searcher/src/server/types.rs

@@ -6,3 +6,32 @@ pub struct ResponseHello {
     pub analyzed: u64,
     pub samples: Vec<String>,
 }
+
+#[derive(Serialize)]
+pub struct ResponseSearch {
+    pub query: String,
+    pub docs: Vec<DocumentSummary>,
+    pub results: Vec<SearchResult>,
+}
+
+#[derive(Serialize)]
+pub struct SearchResult {
+    pub id: u64,
+    pub score: f32,
+}
+
+#[derive(Serialize)]
+pub struct DocumentSummary {
+    pub id: u64,
+    pub name: String,
+    pub signature: String,
+    pub file: String,
+    pub line: LineInfo,
+    pub summary: Option<String>,
+}
+
+#[derive(Serialize)]
+pub struct LineInfo {
+    pub start: i32,
+    pub end: i32,
+}

+ 12 - 8
searcher/src/vector_db/mod.rs

@@ -3,7 +3,7 @@ use cxx::UniquePtr;
 use usearch::ffi::{new_index, Index, IndexOptions, MetricKind, ScalarKind};
 
 #[derive(Message)]
-#[rtype(result = "Vec<i32>")]
+#[rtype(result = "Vec<(u64, f32)>")]
 pub struct Query(pub Vec<f32>);
 
 pub struct VectorDB {
@@ -22,7 +22,7 @@ pub fn start() -> Addr<VectorDB> {
     .unwrap();
 
     index
-        .load("./kuberian.usearch")
+        .view("./kuberian.usearch")
         .expect("fail to load usearch index");
 
     VectorDB { index }.start()
@@ -31,20 +31,24 @@ pub fn start() -> Addr<VectorDB> {
 impl Actor for VectorDB {
     type Context = Context<Self>;
 
-    fn started(&mut self, _ctx: &mut Context<Self>) {}
+    fn started(&mut self, _ctx: &mut Context<Self>) {
+        info!("vector db service initiated");
+    }
 
-    fn stopped(&mut self, _ctx: &mut Context<Self>) {}
+    fn stopped(&mut self, _ctx: &mut Context<Self>) {
+        error!("vector db service shutdown")
+    }
 }
 
 impl Handler<Query> for VectorDB {
-    type Result = Vec<i32>;
+    type Result = Vec<(u64, f32)>;
 
     fn handle(&mut self, msg: Query, _ctx: &mut Context<Self>) -> Self::Result {
         let result = self.index.search(&msg.0, 10).unwrap();
-        let mut converted: Vec<i32> = Vec::new();
+        let mut converted: Vec<(u64, f32)> = Vec::new();
 
-        for val in result.keys.iter() {
-            converted.push(i32::try_from(*val).unwrap())
+        for (i, val) in result.keys.iter().enumerate() {
+            converted.push((*val, *result.distances.get(i).unwrap()))
         }
 
         converted