handler.rs 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. use std::sync::atomic::{AtomicU64, Ordering};
  2. use super::types;
  3. use actix::Addr;
  4. use actix_web::{get, web::Data, web::Path, HttpResponse, Responder};
  5. use r2d2::Pool;
  6. use r2d2_sqlite::SqliteConnectionManager;
  7. use rusqlite::params_from_iter;
  8. use crate::{
  9. embed::encoder::Encoder,
  10. server::utils::repeat_vars,
  11. vector_db::{self, VectorDB},
  12. };
  13. #[get("/healthz")]
  14. async fn healthz() -> impl Responder {
  15. HttpResponse::Ok().body("ok")
  16. }
  17. #[get("/")]
  18. async fn hello(pool: Data<Pool<SqliteConnectionManager>>) -> impl Responder {
  19. let conn = pool.get().unwrap();
  20. let count_functions: u64 = conn
  21. .prepare_cached("SELECT COUNT(*) FROM functions;")
  22. .unwrap()
  23. .query_row([], |row| row.get(0))
  24. .unwrap();
  25. let count_analyses: u64 = conn
  26. .prepare_cached("SELECT COUNT(*) FROM function_analyses;")
  27. .unwrap()
  28. .query_row([], |row| row.get(0))
  29. .unwrap();
  30. let mut samples: Vec<String> = vec![];
  31. for row in conn
  32. .prepare_cached(
  33. r"SELECT summary FROM function_analyses WHERE id
  34. IN (SELECT id FROM function_analyses ORDER BY RANDOM() LIMIT 10);",
  35. )
  36. .unwrap()
  37. .query_map([], |row| row.get::<usize, String>(0))
  38. .unwrap()
  39. .map(|row| row.unwrap())
  40. {
  41. samples.push(row);
  42. }
  43. HttpResponse::Ok().json(types::ResponseHello {
  44. total: count_functions,
  45. analyzed: count_analyses,
  46. samples,
  47. })
  48. }
  49. #[get("/q/{query}")]
  50. async fn search(
  51. path: Path<String>,
  52. enc: Data<Encoder>,
  53. vecdb: Data<Addr<VectorDB>>,
  54. pool: Data<Pool<SqliteConnectionManager>>,
  55. ) -> impl Responder {
  56. let conn = pool.get().unwrap();
  57. let q = path.to_string();
  58. let embeddings = enc.encode(&q);
  59. let doc_ids = vecdb.send(vector_db::Query(embeddings)).await.unwrap();
  60. let mut docs: Vec<types::DocumentSummary> = Vec::new();
  61. conn.prepare_cached(&format!(
  62. r#"
  63. SELECT src.id, name, signature, file, line_start, line_end, tgt.summary
  64. FROM functions as src
  65. LEFT JOIN function_analyses as tgt
  66. ON src.id = tgt.function_id
  67. WHERE src.id IN ({});
  68. "#,
  69. repeat_vars(doc_ids.len())
  70. ))
  71. .expect("fail to query on database")
  72. .query_map(params_from_iter(doc_ids.iter().map(|x| x.0)), |row| {
  73. Ok(types::DocumentSummary {
  74. id: row.get(0).unwrap(),
  75. name: row.get(1).unwrap(),
  76. signature: row.get(2).unwrap(),
  77. file: row.get(3).unwrap(),
  78. line: types::LineInfo {
  79. start: row.get(4).unwrap(),
  80. end: row.get(5).unwrap(),
  81. },
  82. summary: row.get(6).unwrap(),
  83. })
  84. })
  85. .expect("unexpected result from database")
  86. .for_each(|doc| {
  87. docs.push(doc.unwrap());
  88. });
  89. HttpResponse::Ok().json(types::ResponseSearch {
  90. query: q,
  91. results: doc_ids
  92. .iter()
  93. .map(|x| types::SearchResult {
  94. id: x.0,
  95. score: x.1,
  96. })
  97. .collect(),
  98. docs,
  99. })
  100. }