args.rs 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. use crate::embed;
  2. use anyhow::{anyhow, Error as E, Result};
  3. use candle::Device;
  4. use candle_nn::VarBuilder;
  5. use clap::Parser;
  6. use embed::model::{BertModel, Config, DTYPE};
  7. use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
  8. use tokenizers::Tokenizer;
  9. #[derive(Parser, Debug)]
  10. #[command(author, version, about, long_about = None)]
  11. pub struct Args {
  12. /// teminate immediately (just for downloading BERT model)
  13. #[arg(long)]
  14. ci: bool,
  15. /// Run offline (you must have the files already cached)
  16. #[arg(long)]
  17. offline: bool,
  18. /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
  19. #[arg(long)]
  20. model_id: Option<String>,
  21. #[arg(long)]
  22. revision: Option<String>,
  23. }
  24. impl Args {
  25. pub fn terminate_if_ci(&self) {
  26. if self.ci {
  27. info!("terminating ci mode");
  28. std::process::exit(0)
  29. }
  30. }
  31. pub fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
  32. let device = Device::Cpu;
  33. let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
  34. // source: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/discussions/21
  35. let default_revision = "refs/pr/21".to_string();
  36. let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
  37. (Some(model_id), Some(revision)) => (model_id, revision),
  38. (Some(model_id), None) => (model_id, "main".to_string()),
  39. (None, Some(revision)) => (default_model, revision),
  40. (None, None) => (default_model, default_revision),
  41. };
  42. let repo = Repo::with_revision(model_id, RepoType::Model, revision);
  43. let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
  44. let cache = Cache::default();
  45. (
  46. cache
  47. .get(&repo, "config.json")
  48. .ok_or(anyhow!("Missing config file in cache"))?,
  49. cache
  50. .get(&repo, "tokenizer.json")
  51. .ok_or(anyhow!("Missing tokenizer file in cache"))?,
  52. cache
  53. .get(&repo, "model.safetensors")
  54. .ok_or(anyhow!("Missing weights file in cache"))?,
  55. )
  56. } else {
  57. let api = Api::new()?;
  58. let api = api.repo(repo);
  59. (
  60. api.get("config.json")?,
  61. api.get("tokenizer.json")?,
  62. api.get("model.safetensors")?,
  63. )
  64. };
  65. let config = std::fs::read_to_string(config_filename)?;
  66. let config: Config = serde_json::from_str(&config)?;
  67. let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
  68. let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
  69. let weights = weights.deserialize()?;
  70. let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
  71. let model = BertModel::load(vb, &config)?;
  72. Ok((model, tokenizer))
  73. }
  74. }