|
@@ -3,7 +3,9 @@ extern crate intel_mkl_src;
|
|
|
mod args;
|
|
|
mod database;
|
|
|
mod embed;
|
|
|
+mod vector_db;
|
|
|
|
|
|
+use actix::Addr;
|
|
|
use actix_web::{get, web, App, HttpResponse, HttpServer, Responder};
|
|
|
use args::Args;
|
|
|
use clap::Parser;
|
|
@@ -15,13 +17,15 @@ use diesel::{
|
|
|
use embed::encoder;
|
|
|
use serde_json::json;
|
|
|
|
|
|
+use database::schema::function_analyses::dsl::*;
|
|
|
+use database::schema::functions::dsl::*;
|
|
|
+
|
|
|
+use crate::database::models::FunctionMeta;
|
|
|
+
|
|
|
#[get("/")]
|
|
|
async fn hello(pool: web::Data<Pool<ConnectionManager<SqliteConnection>>>) -> impl Responder {
|
|
|
let conn = &mut pool.get().unwrap();
|
|
|
|
|
|
- use database::schema::function_analyses::dsl::*;
|
|
|
- use database::schema::functions::dsl::*;
|
|
|
-
|
|
|
let count_functions = functions
|
|
|
.count()
|
|
|
.get_result::<i64>(conn)
|
|
@@ -39,10 +43,25 @@ async fn hello(pool: web::Data<Pool<ConnectionManager<SqliteConnection>>>) -> im
|
|
|
}
|
|
|
|
|
|
#[get("/q/{query}")]
|
|
|
-async fn search(path: web::Path<String>, enc: web::Data<encoder::Encoder>) -> impl Responder {
|
|
|
+async fn search(
|
|
|
+ path: web::Path<String>,
|
|
|
+ enc: web::Data<encoder::Encoder>,
|
|
|
+ vecdb: web::Data<Addr<vector_db::VectorDB>>,
|
|
|
+ pool: web::Data<Pool<ConnectionManager<SqliteConnection>>>,
|
|
|
+) -> impl Responder {
|
|
|
+ let conn = &mut pool.get().unwrap();
|
|
|
+
|
|
|
let q = path.to_string();
|
|
|
- let vec = enc.encode(&q).unwrap().to_string();
|
|
|
- HttpResponse::Ok().body(vec)
|
|
|
+ let embeddings = enc.encode(&q);
|
|
|
+ let doc_ids = vecdb.send(vector_db::Query(embeddings)).await.unwrap();
|
|
|
+
|
|
|
+ let results: Vec<FunctionMeta> = functions
|
|
|
+ .filter(database::schema::functions::dsl::id.eq_any(doc_ids.clone()))
|
|
|
+ .select(database::models::FunctionMeta::as_select())
|
|
|
+ .load(conn)
|
|
|
+ .expect("Error loading posts");
|
|
|
+
|
|
|
+ HttpResponse::Ok().json(json!({"result": doc_ids, "r": results}))
|
|
|
}
|
|
|
|
|
|
#[actix_web::main]
|
|
@@ -57,8 +76,12 @@ async fn main() -> std::io::Result<()> {
|
|
|
let app_data_encoder = web::Data::new(enc);
|
|
|
|
|
|
// DATABSE
|
|
|
- let pool = database::establish_connection();
|
|
|
- let app_data_pool = web::Data::new(pool);
|
|
|
+ let database = database::establish_connection();
|
|
|
+ let app_data_database = web::Data::new(database);
|
|
|
+
|
|
|
+ // VECTOR DATABASE
|
|
|
+ let vector = vector_db::start();
|
|
|
+ let app_data_vector = web::Data::new(vector);
|
|
|
|
|
|
// [END] INIT
|
|
|
args.terminate_if_ci();
|
|
@@ -67,7 +90,8 @@ async fn main() -> std::io::Result<()> {
|
|
|
let result = HttpServer::new(move || {
|
|
|
App::new()
|
|
|
.app_data(app_data_encoder.clone())
|
|
|
- .app_data(app_data_pool.clone())
|
|
|
+ .app_data(app_data_database.clone())
|
|
|
+ .app_data(app_data_vector.clone())
|
|
|
.service(hello)
|
|
|
.service(search)
|
|
|
})
|