|
@@ -9,32 +9,27 @@ use actix::Addr;
|
|
|
use actix_web::{get, web, App, HttpResponse, HttpServer, Responder};
|
|
|
use args::Args;
|
|
|
use clap::Parser;
|
|
|
-use diesel::prelude::*;
|
|
|
-use diesel::{
|
|
|
- r2d2::{ConnectionManager, Pool},
|
|
|
- SqliteConnection,
|
|
|
-};
|
|
|
use embed::encoder;
|
|
|
+use r2d2::Pool;
|
|
|
+use r2d2_sqlite::SqliteConnectionManager;
|
|
|
+use rusqlite::params_from_iter;
|
|
|
use serde_json::json;
|
|
|
|
|
|
-use database::schema::function_analyses::dsl::*;
|
|
|
-use database::schema::functions::dsl::*;
|
|
|
-
|
|
|
-use crate::database::models::FunctionMeta;
|
|
|
-
|
|
|
#[get("/")]
|
|
|
-async fn hello(pool: web::Data<Pool<ConnectionManager<SqliteConnection>>>) -> impl Responder {
|
|
|
- let conn = &mut pool.get().unwrap();
|
|
|
+async fn hello(pool: web::Data<Pool<SqliteConnectionManager>>) -> impl Responder {
|
|
|
+ let conn = pool.get().unwrap();
|
|
|
|
|
|
- let count_functions = functions
|
|
|
- .count()
|
|
|
- .get_result::<i64>(conn)
|
|
|
- .expect("can not get functions stats");
|
|
|
+ let count_functions: u64 = conn
|
|
|
+ .prepare_cached("SELECT COUNT(*) FROM functions;")
|
|
|
+ .unwrap()
|
|
|
+ .query_row([], |row| row.get(0))
|
|
|
+ .unwrap();
|
|
|
|
|
|
- let count_analyses = function_analyses
|
|
|
- .count()
|
|
|
- .get_result::<i64>(conn)
|
|
|
- .expect("can not get function_analyses stats");
|
|
|
+ let count_analyses: u64 = conn
|
|
|
+ .prepare_cached("SELECT COUNT(*) FROM function_analyses;")
|
|
|
+ .unwrap()
|
|
|
+ .query_row([], |row| row.get(0))
|
|
|
+ .unwrap();
|
|
|
|
|
|
HttpResponse::Ok().json(json!({
|
|
|
"total": count_functions,
|
|
@@ -47,21 +42,29 @@ async fn search(
|
|
|
path: web::Path<String>,
|
|
|
enc: web::Data<encoder::Encoder>,
|
|
|
vecdb: web::Data<Addr<vector_db::VectorDB>>,
|
|
|
- pool: web::Data<Pool<ConnectionManager<SqliteConnection>>>,
|
|
|
+ pool: web::Data<Pool<SqliteConnectionManager>>,
|
|
|
) -> impl Responder {
|
|
|
- let conn = &mut pool.get().unwrap();
|
|
|
+ let conn = pool.get().unwrap();
|
|
|
|
|
|
let q = path.to_string();
|
|
|
let embeddings = enc.encode(&q);
|
|
|
let doc_ids = vecdb.send(vector_db::Query(embeddings)).await.unwrap();
|
|
|
-
|
|
|
- let results: Vec<FunctionMeta> = functions
|
|
|
- .filter(database::schema::functions::dsl::id.eq_any(doc_ids.clone()))
|
|
|
- .select(database::models::FunctionMeta::as_select())
|
|
|
- .load(conn)
|
|
|
- .expect("Error loading posts");
|
|
|
-
|
|
|
- HttpResponse::Ok().json(json!({"result": doc_ids, "r": results}))
|
|
|
+ let mut docs: Vec<String> = Vec::new();
|
|
|
+ for doc in conn
|
|
|
+ .prepare_cached(&format!(
|
|
|
+ "SELECT id, name, signature FROM functions WHERE id IN ({});",
|
|
|
+ repeat_vars(doc_ids.len())
|
|
|
+ ))
|
|
|
+ .unwrap()
|
|
|
+ .query_map(params_from_iter(doc_ids.iter()), |row| {
|
|
|
+ row.get::<usize, String>(2)
|
|
|
+ })
|
|
|
+ .unwrap()
|
|
|
+ {
|
|
|
+ docs.push(doc.unwrap());
|
|
|
+ }
|
|
|
+
|
|
|
+ HttpResponse::Ok().json(json!({"result": doc_ids, "test": docs}))
|
|
|
}
|
|
|
|
|
|
#[actix_web::main]
|
|
@@ -76,7 +79,7 @@ async fn main() -> std::io::Result<()> {
|
|
|
let app_data_encoder = web::Data::new(enc);
|
|
|
|
|
|
// DATABSE
|
|
|
- let database = database::establish_connection();
|
|
|
+ let database = database::new_connection_pool();
|
|
|
let app_data_database = web::Data::new(database);
|
|
|
|
|
|
// VECTOR DATABASE
|
|
@@ -101,3 +104,11 @@ async fn main() -> std::io::Result<()> {
|
|
|
println!("Server Terminated. Byebye :-)");
|
|
|
result
|
|
|
}
|
|
|
+
|
|
|
+fn repeat_vars(count: usize) -> String {
|
|
|
+ assert_ne!(count, 0);
|
|
|
+ let mut s = "?,".repeat(count);
|
|
|
+ // Remove trailing comma
|
|
|
+ s.pop();
|
|
|
+ s
|
|
|
+}
|