ソースを参照

feat(searcher): use usearch for vector serach

iwanhae 1 年間 前
コミット
b26c072e4b

+ 3 - 0
searcher/Cargo.toml

@@ -20,6 +20,9 @@ intel-mkl-src = { version = "0.8.1", features = [
     "mkl-static-lp64-iomp",
 ], optional = true }
 diesel = { version = "2.1.0", features = ["sqlite", "r2d2"] }
+usearch = "1.1.1"
+actix = "0.13.0"
+cxx = "1.0.106"
 
 
 [features]

+ 5 - 2
searcher/src/database/models.rs

@@ -1,10 +1,12 @@
 use super::schema;
 use diesel::prelude::*;
+use serde::{Deserialize, Serialize};
 
 #[derive(Queryable, Selectable)]
 #[diesel(table_name = schema::function_analyses)]
 #[diesel(check_for_backend(diesel::sqlite::Sqlite))]
-pub struct FunctionAnalyses {
+#[derive(Serialize, Deserialize)]
+pub struct FunctionAnalyzed {
     pub function_id: i32,
     pub summary: String,
     pub background: Option<String>,
@@ -17,7 +19,8 @@ pub struct FunctionAnalyses {
 #[derive(Queryable, Selectable)]
 #[diesel(table_name = schema::functions)]
 #[diesel(check_for_backend(diesel::sqlite::Sqlite))]
-pub struct Functions {
+#[derive(Serialize, Deserialize)]
+pub struct FunctionMeta {
     pub id: i32,
     pub name: String,
     pub signature: String,

+ 16 - 7
searcher/src/embed/encoder.rs

@@ -15,7 +15,7 @@ impl Encoder {
         let tokenizer = tokenizer.clone();
         Encoder { model, tokenizer }
     }
-    pub fn encode(&self, prompt: &str) -> Result<Tensor> {
+    pub fn encode(&self, prompt: &str) -> Vec<f32> {
         let start = Instant::now();
         let tokens = self
             .tokenizer
@@ -26,15 +26,24 @@ impl Encoder {
             .unwrap()
             .get_ids()
             .to_vec();
-        let token_ids = Tensor::new(&tokens[..], &self.model.device)?.unsqueeze(0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        let embeddings = self.model.forward(&token_ids, &token_type_ids)?;
-        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+        let token_ids = Tensor::new(&tokens[..], &self.model.device)
+            .unwrap()
+            .unsqueeze(0)
+            .unwrap();
+        let token_type_ids = token_ids.zeros_like().unwrap();
+        let embeddings = self.model.forward(&token_ids, &token_type_ids).unwrap();
+        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3().unwrap();
 
         // mean pooling
-        let embeddings = (embeddings.sum(1)? / (n_tokens as f64)).unwrap();
+        let embeddings = (embeddings.sum(1).unwrap() / (n_tokens as f64)).unwrap();
         dbg!(prompt, start.elapsed());
-        normalize_l2(&embeddings)?.get(0)
+        let tensor = normalize_l2(&embeddings).unwrap().get(0).unwrap();
+
+        let mut result: Vec<f32> = Vec::new();
+        for i in 0..384 {
+            result.push(tensor.get(i).unwrap().to_scalar::<f32>().unwrap())
+        }
+        result
     }
 }
 

+ 33 - 9
searcher/src/main.rs

@@ -3,7 +3,9 @@ extern crate intel_mkl_src;
 mod args;
 mod database;
 mod embed;
+mod vector_db;
 
+use actix::Addr;
 use actix_web::{get, web, App, HttpResponse, HttpServer, Responder};
 use args::Args;
 use clap::Parser;
@@ -15,13 +17,15 @@ use diesel::{
 use embed::encoder;
 use serde_json::json;
 
+use database::schema::function_analyses::dsl::*;
+use database::schema::functions::dsl::*;
+
+use crate::database::models::FunctionMeta;
+
 #[get("/")]
 async fn hello(pool: web::Data<Pool<ConnectionManager<SqliteConnection>>>) -> impl Responder {
     let conn = &mut pool.get().unwrap();
 
-    use database::schema::function_analyses::dsl::*;
-    use database::schema::functions::dsl::*;
-
     let count_functions = functions
         .count()
         .get_result::<i64>(conn)
@@ -39,10 +43,25 @@ async fn hello(pool: web::Data<Pool<ConnectionManager<SqliteConnection>>>) -> im
 }
 
 #[get("/q/{query}")]
-async fn search(path: web::Path<String>, enc: web::Data<encoder::Encoder>) -> impl Responder {
+async fn search(
+    path: web::Path<String>,
+    enc: web::Data<encoder::Encoder>,
+    vecdb: web::Data<Addr<vector_db::VectorDB>>,
+    pool: web::Data<Pool<ConnectionManager<SqliteConnection>>>,
+) -> impl Responder {
+    let conn = &mut pool.get().unwrap();
+
     let q = path.to_string();
-    let vec = enc.encode(&q).unwrap().to_string();
-    HttpResponse::Ok().body(vec)
+    let embeddings = enc.encode(&q);
+    let doc_ids = vecdb.send(vector_db::Query(embeddings)).await.unwrap();
+
+    let results: Vec<FunctionMeta> = functions
+        .filter(database::schema::functions::dsl::id.eq_any(doc_ids.clone()))
+        .select(database::models::FunctionMeta::as_select())
+        .load(conn)
+        .expect("Error loading posts");
+
+    HttpResponse::Ok().json(json!({"result": doc_ids, "r": results}))
 }
 
 #[actix_web::main]
@@ -57,8 +76,12 @@ async fn main() -> std::io::Result<()> {
     let app_data_encoder = web::Data::new(enc);
 
     // DATABSE
-    let pool = database::establish_connection();
-    let app_data_pool = web::Data::new(pool);
+    let database = database::establish_connection();
+    let app_data_database = web::Data::new(database);
+
+    // VECTOR DATABASE
+    let vector = vector_db::start();
+    let app_data_vector = web::Data::new(vector);
 
     // [END] INIT
     args.terminate_if_ci();
@@ -67,7 +90,8 @@ async fn main() -> std::io::Result<()> {
     let result = HttpServer::new(move || {
         App::new()
             .app_data(app_data_encoder.clone())
-            .app_data(app_data_pool.clone())
+            .app_data(app_data_database.clone())
+            .app_data(app_data_vector.clone())
             .service(hello)
             .service(search)
     })

+ 71 - 0
searcher/src/vector_db/mod.rs

@@ -0,0 +1,71 @@
+use actix::prelude::*;
+use cxx::UniquePtr;
+use usearch::ffi::{new_index, Index, IndexOptions, MetricKind, ScalarKind};
+
+#[derive(Message)]
+#[rtype(result = "Vec<i32>")]
+pub struct Query(pub Vec<f32>);
+
+pub struct VectorDB {
+    index: UniquePtr<Index>,
+}
+
+pub fn start() -> Addr<VectorDB> {
+    let index = new_index(&IndexOptions {
+        metric: MetricKind::Cos,
+        quantization: ScalarKind::F16,
+        dimensions: 384,
+        connectivity: 0,
+        expansion_add: 0,
+        expansion_search: 0,
+    })
+    .unwrap();
+
+    index
+        .load("./kuberian.usearch")
+        .expect("fail to load usearch index");
+
+    VectorDB { index }.start()
+}
+
+impl Actor for VectorDB {
+    type Context = Context<Self>;
+
+    fn started(&mut self, _ctx: &mut Context<Self>) {
+        let index = new_index(&IndexOptions {
+            metric: MetricKind::Cos,
+            quantization: ScalarKind::F16,
+            dimensions: 384,
+            connectivity: 0,
+            expansion_add: 0,
+            expansion_search: 0,
+        })
+        .unwrap();
+
+        index
+            .load("./kuberian.usearch")
+            .expect("fail to load usearch index");
+
+        self.index = index;
+        dbg!("started");
+    }
+
+    fn stopped(&mut self, _ctx: &mut Context<Self>) {
+        dbg!("terminated");
+    }
+}
+
+impl Handler<Query> for VectorDB {
+    type Result = Vec<i32>;
+
+    fn handle(&mut self, msg: Query, _ctx: &mut Context<Self>) -> Self::Result {
+        let result = self.index.search(&msg.0, 10).unwrap();
+        let mut converted: Vec<i32> = Vec::new();
+
+        for val in result.keys.iter() {
+            converted.push(i32::try_from(*val).unwrap())
+        }
+
+        converted
+    }
+}