瀏覽代碼

feat(searcher): seperates http handler module

iwanhae 1 年之前
父節點
當前提交
3a6c3e658c
共有 4 個文件被更改,包括 77 次插入68 次删除
  1. 4 68
      searcher/src/main.rs
  2. 64 0
      searcher/src/server/handler.rs
  3. 2 0
      searcher/src/server/mod.rs
  4. 7 0
      searcher/src/server/utils.rs

+ 4 - 68
searcher/src/main.rs

@@ -3,69 +3,13 @@ extern crate intel_mkl_src;
 mod args;
 mod database;
 mod embed;
+mod server;
 mod vector_db;
 
-use actix::Addr;
-use actix_web::{get, web, App, HttpResponse, HttpServer, Responder};
+use actix_web::{web, App, HttpServer};
 use args::Args;
 use clap::Parser;
 use embed::encoder;
-use r2d2::Pool;
-use r2d2_sqlite::SqliteConnectionManager;
-use rusqlite::params_from_iter;
-use serde_json::json;
-
-#[get("/")]
-async fn hello(pool: web::Data<Pool<SqliteConnectionManager>>) -> impl Responder {
-    let conn = pool.get().unwrap();
-
-    let count_functions: u64 = conn
-        .prepare_cached("SELECT COUNT(*) FROM functions;")
-        .unwrap()
-        .query_row([], |row| row.get(0))
-        .unwrap();
-
-    let count_analyses: u64 = conn
-        .prepare_cached("SELECT COUNT(*) FROM function_analyses;")
-        .unwrap()
-        .query_row([], |row| row.get(0))
-        .unwrap();
-
-    HttpResponse::Ok().json(json!({
-        "total": count_functions,
-        "analyzed": count_analyses
-    }))
-}
-
-#[get("/q/{query}")]
-async fn search(
-    path: web::Path<String>,
-    enc: web::Data<encoder::Encoder>,
-    vecdb: web::Data<Addr<vector_db::VectorDB>>,
-    pool: web::Data<Pool<SqliteConnectionManager>>,
-) -> impl Responder {
-    let conn = pool.get().unwrap();
-
-    let q = path.to_string();
-    let embeddings = enc.encode(&q);
-    let doc_ids = vecdb.send(vector_db::Query(embeddings)).await.unwrap();
-    let mut docs: Vec<String> = Vec::new();
-    for doc in conn
-        .prepare_cached(&format!(
-            "SELECT id, name, signature FROM functions WHERE id IN ({});",
-            repeat_vars(doc_ids.len())
-        ))
-        .unwrap()
-        .query_map(params_from_iter(doc_ids.iter()), |row| {
-            row.get::<usize, String>(2)
-        })
-        .unwrap()
-    {
-        docs.push(doc.unwrap());
-    }
-
-    HttpResponse::Ok().json(json!({"result": doc_ids, "test": docs}))
-}
 
 #[actix_web::main]
 async fn main() -> std::io::Result<()> {
@@ -95,8 +39,8 @@ async fn main() -> std::io::Result<()> {
             .app_data(app_data_encoder.clone())
             .app_data(app_data_database.clone())
             .app_data(app_data_vector.clone())
-            .service(hello)
-            .service(search)
+            .service(server::handler::hello)
+            .service(server::handler::search)
     })
     .bind(("0.0.0.0", 8080))?
     .run()
@@ -104,11 +48,3 @@ async fn main() -> std::io::Result<()> {
     println!("Server Terminated. Byebye :-)");
     result
 }
-
-fn repeat_vars(count: usize) -> String {
-    assert_ne!(count, 0);
-    let mut s = "?,".repeat(count);
-    // Remove trailing comma
-    s.pop();
-    s
-}

+ 64 - 0
searcher/src/server/handler.rs

@@ -0,0 +1,64 @@
+use actix::Addr;
+use actix_web::{get, web::Data, web::Path, HttpResponse, Responder};
+use r2d2::Pool;
+use r2d2_sqlite::SqliteConnectionManager;
+use rusqlite::params_from_iter;
+use serde_json::json;
+
+use crate::{
+    embed::encoder::Encoder,
+    server::utils::repeat_vars,
+    vector_db::{self, VectorDB},
+};
+
+#[get("/")]
+async fn hello(pool: Data<Pool<SqliteConnectionManager>>) -> impl Responder {
+    let conn = pool.get().unwrap();
+
+    let count_functions: u64 = conn
+        .prepare_cached("SELECT COUNT(*) FROM functions;")
+        .unwrap()
+        .query_row([], |row| row.get(0))
+        .unwrap();
+
+    let count_analyses: u64 = conn
+        .prepare_cached("SELECT COUNT(*) FROM function_analyses;")
+        .unwrap()
+        .query_row([], |row| row.get(0))
+        .unwrap();
+
+    HttpResponse::Ok().json(json!({
+        "total": count_functions,
+        "analyzed": count_analyses
+    }))
+}
+
+#[get("/q/{query}")]
+async fn search(
+    path: Path<String>,
+    enc: Data<Encoder>,
+    vecdb: Data<Addr<VectorDB>>,
+    pool: Data<Pool<SqliteConnectionManager>>,
+) -> impl Responder {
+    let conn = pool.get().unwrap();
+
+    let q = path.to_string();
+    let embeddings = enc.encode(&q);
+    let doc_ids = vecdb.send(vector_db::Query(embeddings)).await.unwrap();
+    let mut docs: Vec<String> = Vec::new();
+    for doc in conn
+        .prepare_cached(&format!(
+            "SELECT id, name, signature FROM functions WHERE id IN ({});",
+            repeat_vars(doc_ids.len())
+        ))
+        .unwrap()
+        .query_map(params_from_iter(doc_ids.iter()), |row| {
+            row.get::<usize, String>(2)
+        })
+        .unwrap()
+    {
+        docs.push(doc.unwrap());
+    }
+
+    HttpResponse::Ok().json(json!({"result": doc_ids, "test": docs}))
+}

+ 2 - 0
searcher/src/server/mod.rs

@@ -0,0 +1,2 @@
+pub mod handler;
+mod utils;

+ 7 - 0
searcher/src/server/utils.rs

@@ -0,0 +1,7 @@
+pub fn repeat_vars(count: usize) -> String {
+    assert_ne!(count, 0);
+    let mut s = "?,".repeat(count);
+    // Remove trailing comma
+    s.pop();
+    s
+}