Forráskód Böngészése

feat(searcher): search api

iwanhae 1 éve
szülő
commit
54c24fd4d9

+ 1 - 3
searcher/src/database/mod.rs

@@ -6,7 +6,5 @@ use r2d2_sqlite::SqliteConnectionManager;
 pub fn new_connection_pool() -> Pool<SqliteConnectionManager> {
     let database_url = env::var("DATABASE_URL").unwrap_or_else(|_| String::from("./kuberian.db"));
     let manager = SqliteConnectionManager::file(database_url);
-    let pool = Pool::new(manager).expect("fail to generate db connection pool");
-
-    pool
+    Pool::new(manager).expect("fail to generate db connection pool")
 }

+ 5 - 4
searcher/src/embed/encoder.rs

@@ -1,7 +1,7 @@
 use super::model;
 
 use anyhow::Error as E;
-use candle::{Result, Tensor};
+use candle::Tensor;
 use std::time::Instant;
 use tokenizers::Tokenizer;
 
@@ -37,7 +37,7 @@ impl Encoder {
         // mean pooling
         let embeddings = (embeddings.sum(1).unwrap() / (n_tokens as f64)).unwrap();
         dbg!(prompt, start.elapsed());
-        let tensor = normalize_l2(&embeddings).unwrap().get(0).unwrap();
+        let tensor = normalize_l2(&embeddings).get(0).unwrap();
 
         let mut result: Vec<f32> = Vec::new();
         for i in 0..384 {
@@ -47,6 +47,7 @@ impl Encoder {
     }
 }
 
-pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
-    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+pub fn normalize_l2(v: &Tensor) -> Tensor {
+    v.broadcast_div(&v.sqr().unwrap().sum_keepdim(1).unwrap().sqrt().unwrap())
+        .unwrap()
 }

+ 55 - 18
searcher/src/server/handler.rs

@@ -1,9 +1,9 @@
+use super::types;
 use actix::Addr;
 use actix_web::{get, web::Data, web::Path, HttpResponse, Responder};
 use r2d2::Pool;
 use r2d2_sqlite::SqliteConnectionManager;
 use rusqlite::params_from_iter;
-use serde_json::json;
 
 use crate::{
     embed::encoder::Encoder,
@@ -27,10 +27,22 @@ async fn hello(pool: Data<Pool<SqliteConnectionManager>>) -> impl Responder {
         .query_row([], |row| row.get(0))
         .unwrap();
 
-    HttpResponse::Ok().json(json!({
-        "total": count_functions,
-        "analyzed": count_analyses
-    }))
+    let mut samples: Vec<String> = vec![];
+    for row in conn
+        .prepare_cached("SELECT summary FROM function_analyses ORDER BY RANDOM() LIMIT 10;")
+        .unwrap()
+        .query_map([], |row| row.get::<usize, String>(0))
+        .unwrap()
+        .map(|row| row.unwrap())
+    {
+        samples.push(row);
+    }
+
+    HttpResponse::Ok().json(types::ResponseHello {
+        total: count_functions,
+        analyzed: count_analyses,
+        samples,
+    })
 }
 
 #[get("/q/{query}")]
@@ -45,20 +57,45 @@ async fn search(
     let q = path.to_string();
     let embeddings = enc.encode(&q);
     let doc_ids = vecdb.send(vector_db::Query(embeddings)).await.unwrap();
-    let mut docs: Vec<String> = Vec::new();
-    for doc in conn
-        .prepare_cached(&format!(
-            "SELECT id, name, signature FROM functions WHERE id IN ({});",
-            repeat_vars(doc_ids.len())
-        ))
-        .unwrap()
-        .query_map(params_from_iter(doc_ids.iter()), |row| {
-            row.get::<usize, String>(2)
+    let mut docs: Vec<types::DocumentSummary> = Vec::new();
+    conn.prepare_cached(&format!(
+        r#"
+            SELECT src.id, name, signature, file, line_start, line_end, tgt.summary 
+            FROM functions as src 
+                LEFT JOIN function_analyses as tgt 
+                ON src.id = tgt.function_id 
+            WHERE src.id IN ({});
+            "#,
+        repeat_vars(doc_ids.len())
+    ))
+    .expect("fail to query on database")
+    .query_map(params_from_iter(doc_ids.iter().map(|x| x.0)), |row| {
+        Ok(types::DocumentSummary {
+            id: row.get(0).unwrap(),
+            name: row.get(1).unwrap(),
+            signature: row.get(2).unwrap(),
+            file: row.get(3).unwrap(),
+            line: types::LineInfo {
+                start: row.get(4).unwrap(),
+                end: row.get(5).unwrap(),
+            },
+            summary: row.get(6).unwrap(),
         })
-        .unwrap()
-    {
+    })
+    .expect("unexpected result from database")
+    .for_each(|doc| {
         docs.push(doc.unwrap());
-    }
+    });
 
-    HttpResponse::Ok().json(json!({"result": doc_ids, "test": docs}))
+    HttpResponse::Ok().json(types::ResponseSearch {
+        query: q,
+        results: doc_ids
+            .iter()
+            .map(|x| types::SearchResult {
+                id: x.0,
+                score: x.1,
+            })
+            .collect(),
+        docs,
+    })
 }

+ 1 - 0
searcher/src/server/mod.rs

@@ -1,2 +1,3 @@
 pub mod handler;
+mod types;
 mod utils;

+ 37 - 0
searcher/src/server/types.rs

@@ -0,0 +1,37 @@
+use serde::Serialize;
+
+#[derive(Serialize)]
+pub struct ResponseHello {
+    pub total: u64,
+    pub analyzed: u64,
+    pub samples: Vec<String>,
+}
+
+#[derive(Serialize)]
+pub struct ResponseSearch {
+    pub query: String,
+    pub docs: Vec<DocumentSummary>,
+    pub results: Vec<SearchResult>,
+}
+
+#[derive(Serialize)]
+pub struct SearchResult {
+    pub id: u64,
+    pub score: f32,
+}
+
+#[derive(Serialize)]
+pub struct DocumentSummary {
+    pub id: u64,
+    pub name: String,
+    pub signature: String,
+    pub file: String,
+    pub line: LineInfo,
+    pub summary: Option<String>,
+}
+
+#[derive(Serialize)]
+pub struct LineInfo {
+    pub start: i32,
+    pub end: i32,
+}

+ 5 - 5
searcher/src/vector_db/mod.rs

@@ -3,7 +3,7 @@ use cxx::UniquePtr;
 use usearch::ffi::{new_index, Index, IndexOptions, MetricKind, ScalarKind};
 
 #[derive(Message)]
-#[rtype(result = "Vec<i32>")]
+#[rtype(result = "Vec<(u64, f32)>")]
 pub struct Query(pub Vec<f32>);
 
 pub struct VectorDB {
@@ -37,14 +37,14 @@ impl Actor for VectorDB {
 }
 
 impl Handler<Query> for VectorDB {
-    type Result = Vec<i32>;
+    type Result = Vec<(u64, f32)>;
 
     fn handle(&mut self, msg: Query, _ctx: &mut Context<Self>) -> Self::Result {
         let result = self.index.search(&msg.0, 10).unwrap();
-        let mut converted: Vec<i32> = Vec::new();
+        let mut converted: Vec<(u64, f32)> = Vec::new();
 
-        for val in result.keys.iter() {
-            converted.push(i32::try_from(*val).unwrap())
+        for (i, val) in result.keys.iter().enumerate() {
+            converted.push((*val, *result.distances.get(i).unwrap()))
         }
 
         converted