Browse Source

feat(searcher): use r2d2 instaed of diesel

iwanhae 1 year ago
parent
commit
2cd0a9849e

+ 3 - 1
searcher/Cargo.toml

@@ -19,7 +19,9 @@ actix-web = "4"
 intel-mkl-src = { version = "0.8.1", features = [
     "mkl-static-lp64-iomp",
 ], optional = true }
-diesel = { version = "2.1.0", features = ["sqlite", "r2d2"] }
+r2d2 = "0.8.10"
+r2d2_sqlite = "0.22.0"
+rusqlite = { version = "0.29.0", features = ["bundled", "trace", "array"] }
 usearch = "1.1.1"
 actix = "0.13.0"
 cxx = "1.0.106"

+ 8 - 11
searcher/src/database/mod.rs

@@ -1,15 +1,12 @@
-pub mod models;
-pub mod schema;
-
-use diesel::r2d2::ConnectionManager;
-use diesel::r2d2::Pool;
-use diesel::sqlite::SqliteConnection;
 use std::env;
 
-pub fn establish_connection() -> Pool<ConnectionManager<SqliteConnection>> {
+use r2d2::Pool;
+use r2d2_sqlite::SqliteConnectionManager;
+
+pub fn new_connection_pool() -> Pool<SqliteConnectionManager> {
     let database_url = env::var("DATABASE_URL").unwrap_or_else(|_| String::from("./kuberian.db"));
-    let manger = ConnectionManager::<SqliteConnection>::new(database_url);
-    Pool::builder()
-        .build(manger)
-        .expect("Could not build connection pool")
+    let manager = SqliteConnectionManager::file(database_url);
+    let pool = Pool::new(manager).expect("fail to generate db connection pool");
+
+    pool
 }

+ 0 - 31
searcher/src/database/models.rs

@@ -1,31 +0,0 @@
-use super::schema;
-use diesel::prelude::*;
-use serde::{Deserialize, Serialize};
-
-#[derive(Queryable, Selectable)]
-#[diesel(table_name = schema::function_analyses)]
-#[diesel(check_for_backend(diesel::sqlite::Sqlite))]
-#[derive(Serialize, Deserialize)]
-pub struct FunctionAnalyzed {
-    pub function_id: i32,
-    pub summary: String,
-    pub background: Option<String>,
-    pub analysis: Option<String>,
-    pub purpose: Option<String>,
-    pub comment: Option<String>,
-    pub tldr: Option<String>,
-}
-
-#[derive(Queryable, Selectable)]
-#[diesel(table_name = schema::functions)]
-#[diesel(check_for_backend(diesel::sqlite::Sqlite))]
-#[derive(Serialize, Deserialize)]
-pub struct FunctionMeta {
-    pub id: i32,
-    pub name: String,
-    pub signature: String,
-    pub file: String,
-    pub code: String,
-    pub line_start: i32,
-    pub line_end: i32,
-}

+ 0 - 28
searcher/src/database/schema.rs

@@ -1,28 +0,0 @@
-// @generated automatically by Diesel CLI.
-
-diesel::table! {
-    function_analyses (id) {
-        id -> Integer,
-        function_id -> Integer,
-        summary -> Text,
-        background -> Nullable<Text>,
-        analysis -> Nullable<Text>,
-        purpose -> Nullable<Text>,
-        comment -> Nullable<Text>,
-        tldr -> Nullable<Text>,
-    }
-}
-
-diesel::table! {
-    functions (id) {
-        id -> Integer,
-        name -> Text,
-        signature -> Text,
-        file -> Text,
-        code -> Text,
-        line_start -> Integer,
-        line_end -> Integer,
-    }
-}
-
-diesel::allow_tables_to_appear_in_same_query!(function_analyses, functions,);

+ 42 - 31
searcher/src/main.rs

@@ -9,32 +9,27 @@ use actix::Addr;
 use actix_web::{get, web, App, HttpResponse, HttpServer, Responder};
 use args::Args;
 use clap::Parser;
-use diesel::prelude::*;
-use diesel::{
-    r2d2::{ConnectionManager, Pool},
-    SqliteConnection,
-};
 use embed::encoder;
+use r2d2::Pool;
+use r2d2_sqlite::SqliteConnectionManager;
+use rusqlite::params_from_iter;
 use serde_json::json;
 
-use database::schema::function_analyses::dsl::*;
-use database::schema::functions::dsl::*;
-
-use crate::database::models::FunctionMeta;
-
 #[get("/")]
-async fn hello(pool: web::Data<Pool<ConnectionManager<SqliteConnection>>>) -> impl Responder {
-    let conn = &mut pool.get().unwrap();
+async fn hello(pool: web::Data<Pool<SqliteConnectionManager>>) -> impl Responder {
+    let conn = pool.get().unwrap();
 
-    let count_functions = functions
-        .count()
-        .get_result::<i64>(conn)
-        .expect("can not get functions stats");
+    let count_functions: u64 = conn
+        .prepare_cached("SELECT COUNT(*) FROM functions;")
+        .unwrap()
+        .query_row([], |row| row.get(0))
+        .unwrap();
 
-    let count_analyses = function_analyses
-        .count()
-        .get_result::<i64>(conn)
-        .expect("can not get function_analyses stats");
+    let count_analyses: u64 = conn
+        .prepare_cached("SELECT COUNT(*) FROM function_analyses;")
+        .unwrap()
+        .query_row([], |row| row.get(0))
+        .unwrap();
 
     HttpResponse::Ok().json(json!({
         "total": count_functions,
@@ -47,21 +42,29 @@ async fn search(
     path: web::Path<String>,
     enc: web::Data<encoder::Encoder>,
     vecdb: web::Data<Addr<vector_db::VectorDB>>,
-    pool: web::Data<Pool<ConnectionManager<SqliteConnection>>>,
+    pool: web::Data<Pool<SqliteConnectionManager>>,
 ) -> impl Responder {
-    let conn = &mut pool.get().unwrap();
+    let conn = pool.get().unwrap();
 
     let q = path.to_string();
     let embeddings = enc.encode(&q);
     let doc_ids = vecdb.send(vector_db::Query(embeddings)).await.unwrap();
-
-    let results: Vec<FunctionMeta> = functions
-        .filter(database::schema::functions::dsl::id.eq_any(doc_ids.clone()))
-        .select(database::models::FunctionMeta::as_select())
-        .load(conn)
-        .expect("Error loading posts");
-
-    HttpResponse::Ok().json(json!({"result": doc_ids, "r": results}))
+    let mut docs: Vec<String> = Vec::new();
+    for doc in conn
+        .prepare_cached(&format!(
+            "SELECT id, name, signature FROM functions WHERE id IN ({});",
+            repeat_vars(doc_ids.len())
+        ))
+        .unwrap()
+        .query_map(params_from_iter(doc_ids.iter()), |row| {
+            row.get::<usize, String>(2)
+        })
+        .unwrap()
+    {
+        docs.push(doc.unwrap());
+    }
+
+    HttpResponse::Ok().json(json!({"result": doc_ids, "test": docs}))
 }
 
 #[actix_web::main]
@@ -76,7 +79,7 @@ async fn main() -> std::io::Result<()> {
     let app_data_encoder = web::Data::new(enc);
 
     // DATABSE
-    let database = database::establish_connection();
+    let database = database::new_connection_pool();
     let app_data_database = web::Data::new(database);
 
     // VECTOR DATABASE
@@ -101,3 +104,11 @@ async fn main() -> std::io::Result<()> {
     println!("Server Terminated. Byebye :-)");
     result
 }
+
+fn repeat_vars(count: usize) -> String {
+    assert_ne!(count, 0);
+    let mut s = "?,".repeat(count);
+    // Remove trailing comma
+    s.pop();
+    s
+}