model.rs 19 KB

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, VarBuilder};
use serde::Deserialize;

pub const DTYPE: DType = DType::F32;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
enum HiddenAct {
    Gelu,
    Relu,
}

struct HiddenActLayer {
    act: HiddenAct,
    span: tracing::Span,
}

impl HiddenActLayer {
    fn new(act: HiddenAct) -> Self {
        let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
        Self { act, span }
    }

    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
        let _enter = self.span.enter();
        match self.act {
            // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new";
            // this explains some of the small numerical differences.
            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
            HiddenAct::Gelu => xs.gelu(),
            HiddenAct::Relu => xs.relu(),
        }
    }
}
#[derive(Debug)]
pub struct Linear {
    weight: Tensor,
    bias: Option<Tensor>,
    span: tracing::Span,
}

impl Linear {
    pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
        let span = tracing::span!(tracing::Level::TRACE, "linear");
        Self { weight, bias, span }
    }

    pub fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let _enter = self.span.enter();
        let w = match x.dims() {
            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
            _ => self.weight.t()?,
        };
        let x = x.matmul(&w)?;
        match &self.bias {
            None => Ok(x),
            Some(bias) => x.broadcast_add(bias),
        }
    }
}
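
// Illustrative test sketch (not part of the original file): exercises the `Linear`
// layer above with both a 2-D and a batched 3-D input. The shapes and random values
// are arbitrary assumptions chosen for the example.
#[cfg(test)]
mod linear_tests {
    use super::*;

    #[test]
    fn linear_forward_shapes() -> Result<()> {
        let dev = Device::Cpu;
        // Weight is stored as (out_features, in_features), matching the `linear` loader below.
        let weight = Tensor::randn(0f32, 1.0, (3, 4), &dev)?;
        let bias = Tensor::zeros(3usize, DTYPE, &dev)?;
        let layer = Linear::new(weight, Some(bias));
        // A 2-D input (seq_len, in_features) goes through the plain `weight.t()` path.
        let x = Tensor::randn(0f32, 1.0, (5, 4), &dev)?;
        assert_eq!(layer.forward(&x)?.dims(), &[5, 3]);
        // A 3-D input (batch, seq_len, in_features) uses the `broadcast_left` path.
        let x = Tensor::randn(0f32, 1.0, (2, 5, 4), &dev)?;
        assert_eq!(layer.forward(&x)?.dims(), &[2, 5, 3]);
        Ok(())
    }
}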
#[derive(Debug)]
pub struct LayerNorm {
    weight: Tensor,
    bias: Tensor,
    eps: f64,
    span: tracing::Span,
}

impl LayerNorm {
    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
        let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
        Self {
            weight,
            bias,
            eps,
            span,
        }
    }
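
    // The forward pass below is a standard layer normalization over the last (hidden)
    // dimension: y = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias. Half-precision
    // inputs (f16/bf16) are upcast to f32 for the mean/variance reductions.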
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let x_dtype = x.dtype();
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
        let x = x.to_dtype(internal_dtype)?;
        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
        let x = x.broadcast_sub(&mean_x)?;
        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        let x = x_normed
            .to_dtype(x_dtype)?
            .broadcast_mul(&self.weight)?
            .broadcast_add(&self.bias)?;
        Ok(x)
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
    #[default]
    Absolute,
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    vocab_size: usize,
    hidden_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    intermediate_size: usize,
    hidden_act: HiddenAct,
    hidden_dropout_prob: f64,
    max_position_embeddings: usize,
    type_vocab_size: usize,
    initializer_range: f64,
    layer_norm_eps: f64,
    pad_token_id: usize,
    #[serde(default)]
    position_embedding_type: PositionEmbeddingType,
    #[serde(default)]
    use_cache: bool,
    classifier_dropout: Option<f64>,
    model_type: Option<String>,
}
impl Default for Config {
    fn default() -> Self {
        Self {
            vocab_size: 30522,
            hidden_size: 768,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            intermediate_size: 3072,
            hidden_act: HiddenAct::Gelu,
            hidden_dropout_prob: 0.1,
            max_position_embeddings: 512,
            type_vocab_size: 2,
            initializer_range: 0.02,
            layer_norm_eps: 1e-12,
            pad_token_id: 0,
            position_embedding_type: PositionEmbeddingType::Absolute,
            use_cache: true,
            classifier_dropout: None,
            model_type: Some("bert".to_string()),
        }
    }
}
impl Config {
    fn _all_mini_lm_l6_v2() -> Self {
        // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
        Self {
            vocab_size: 30522,
            hidden_size: 384,
            num_hidden_layers: 6,
            num_attention_heads: 12,
            intermediate_size: 1536,
            hidden_act: HiddenAct::Gelu,
            hidden_dropout_prob: 0.1,
            max_position_embeddings: 512,
            type_vocab_size: 2,
            initializer_range: 0.02,
            layer_norm_eps: 1e-12,
            pad_token_id: 0,
            position_embedding_type: PositionEmbeddingType::Absolute,
            use_cache: true,
            classifier_dropout: None,
            model_type: Some("bert".to_string()),
        }
    }
}
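
// Illustrative sketch (not part of the original file): since `Config` derives
// `Deserialize`, a Hugging Face `config.json` can be parsed directly into it. The
// `serde_json` dependency assumed here is an addition made purely for the example.
#[allow(dead_code)]
fn config_from_json(json: &str) -> serde_json::Result<Config> {
    serde_json::from_str(json)
}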
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
    Ok(Embedding::new(embeddings, hidden_size))
}

fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
    let weight = vb.get((size2, size1), "weight")?;
    let bias = vb.get(size2, "bias")?;
    Ok(Linear::new(weight, Some(bias)))
}
struct Dropout {
    #[allow(dead_code)]
    pr: f64,
}

impl Dropout {
    fn new(pr: f64) -> Self {
        Self { pr }
    }
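
    // Dropout is left unimplemented (see the TODO below); at inference time it behaves
    // as an identity and simply clones its input.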
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // TODO
        Ok(x.clone())
    }
}
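
// Some BERT checkpoints store the LayerNorm parameters under the legacy names
// "gamma"/"beta" rather than "weight"/"bias"; the loader below falls back to those
// names when the standard ones are missing.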
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
        (Ok(weight), Ok(bias)) => (weight, bias),
        (Err(err), _) | (_, Err(err)) => {
            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
                (weight, bias)
            } else {
                return Err(err);
            }
        }
    };
    Ok(LayerNorm::new(weight, bias, eps))
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
struct BertEmbeddings {
    word_embeddings: Embedding,
    position_embeddings: Option<Embedding>,
    token_type_embeddings: Embedding,
    layer_norm: LayerNorm,
    dropout: Dropout,
    span: tracing::Span,
}

impl BertEmbeddings {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let word_embeddings = embedding(
            config.vocab_size,
            config.hidden_size,
            vb.pp("word_embeddings"),
        )?;
        let position_embeddings = embedding(
            config.max_position_embeddings,
            config.hidden_size,
            vb.pp("position_embeddings"),
        )?;
        let token_type_embeddings = embedding(
            config.type_vocab_size,
            config.hidden_size,
            vb.pp("token_type_embeddings"),
        )?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            word_embeddings,
            position_embeddings: Some(position_embeddings),
            token_type_embeddings,
            layer_norm,
            dropout: Dropout::new(config.hidden_dropout_prob),
            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
        })
    }
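
    // The output is the sum of the token, token-type and absolute position embeddings,
    // followed by layer normalization and (no-op) dropout.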
    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (_bsize, seq_len) = input_ids.dims2()?;
        let input_embeddings = self.word_embeddings.forward(input_ids)?;
        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
        if let Some(position_embeddings) = &self.position_embeddings {
            // TODO: Proper absolute positions?
            let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
            let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
        }
        let embeddings = self.layer_norm.forward(&embeddings)?;
        let embeddings = self.dropout.forward(&embeddings)?;
        Ok(embeddings)
    }
}
struct BertSelfAttention {
    query: Linear,
    key: Linear,
    value: Linear,
    dropout: Dropout,
    num_attention_heads: usize,
    attention_head_size: usize,
    span: tracing::Span,
    span_softmax: tracing::Span,
}

impl BertSelfAttention {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let attention_head_size = config.hidden_size / config.num_attention_heads;
        let all_head_size = config.num_attention_heads * attention_head_size;
        let dropout = Dropout::new(config.hidden_dropout_prob);
        let hidden_size = config.hidden_size;
        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
        Ok(Self {
            query,
            key,
            value,
            dropout,
            num_attention_heads: config.num_attention_heads,
            attention_head_size,
            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
        })
    }
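
    // Reshapes (batch, seq_len, hidden) into (batch, num_heads, seq_len, head_size) so
    // that the scaled dot-product attention softmax(Q K^T / sqrt(head_size)) V in
    // `forward` can be computed per head with batched matmuls.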
    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
        let mut new_x_shape = xs.dims().to_vec();
        new_x_shape.pop();
        new_x_shape.push(self.num_attention_heads);
        new_x_shape.push(self.attention_head_size);
        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
        xs.contiguous()
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let query_layer = self.query.forward(hidden_states)?;
        let key_layer = self.key.forward(hidden_states)?;
        let value_layer = self.value.forward(hidden_states)?;
        let query_layer = self.transpose_for_scores(&query_layer)?;
        let key_layer = self.transpose_for_scores(&key_layer)?;
        let value_layer = self.transpose_for_scores(&value_layer)?;
        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
        let attention_probs = {
            let _enter_sm = self.span_softmax.enter();
            candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)?
        };
        let attention_probs = self.dropout.forward(&attention_probs)?;
        let context_layer = attention_probs.matmul(&value_layer)?;
        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
        let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?;
        Ok(context_layer)
    }
}
struct BertSelfOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    dropout: Dropout,
    span: tracing::Span,
}

impl BertSelfOutput {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        let dropout = Dropout::new(config.hidden_dropout_prob);
        Ok(Self {
            dense,
            layer_norm,
            dropout,
            span: tracing::span!(tracing::Level::TRACE, "self-out"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden_states = self.dense.forward(hidden_states)?;
        let hidden_states = self.dropout.forward(&hidden_states)?;
        self.layer_norm.forward(&(hidden_states + input_tensor)?)
    }
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
struct BertAttention {
    self_attention: BertSelfAttention,
    self_output: BertSelfOutput,
    span: tracing::Span,
}

impl BertAttention {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
        let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
        Ok(Self {
            self_attention,
            self_output,
            span: tracing::span!(tracing::Level::TRACE, "attn"),
        })
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let self_outputs = self.self_attention.forward(hidden_states)?;
        let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
        Ok(attention_output)
    }
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
struct BertIntermediate {
    dense: Linear,
    intermediate_act: HiddenActLayer,
    span: tracing::Span,
}

impl BertIntermediate {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
        Ok(Self {
            dense,
            intermediate_act: HiddenActLayer::new(config.hidden_act),
            span: tracing::span!(tracing::Level::TRACE, "inter"),
        })
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden_states = self.dense.forward(hidden_states)?;
        let ys = self.intermediate_act.forward(&hidden_states)?;
        Ok(ys)
    }
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
struct BertOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    dropout: Dropout,
    span: tracing::Span,
}

impl BertOutput {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        let dropout = Dropout::new(config.hidden_dropout_prob);
        Ok(Self {
            dense,
            layer_norm,
            dropout,
            span: tracing::span!(tracing::Level::TRACE, "out"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden_states = self.dense.forward(hidden_states)?;
        let hidden_states = self.dropout.forward(&hidden_states)?;
        self.layer_norm.forward(&(hidden_states + input_tensor)?)
    }
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
struct BertLayer {
    attention: BertAttention,
    intermediate: BertIntermediate,
    output: BertOutput,
    span: tracing::Span,
}

impl BertLayer {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let attention = BertAttention::load(vb.pp("attention"), config)?;
        let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
        let output = BertOutput::load(vb.pp("output"), config)?;
        Ok(Self {
            attention,
            intermediate,
            output,
            span: tracing::span!(tracing::Level::TRACE, "layer"),
        })
    }
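
    // One transformer block: self-attention with its residual connection and LayerNorm
    // (inside BertSelfOutput), followed by the feed-forward intermediate/output pair
    // with its own residual connection and LayerNorm.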
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let attention_output = self.attention.forward(hidden_states)?;
        // TODO: Support cross-attention?
        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
        // TODO: Support something similar to `apply_chunking_to_forward`?
        let intermediate_output = self.intermediate.forward(&attention_output)?;
        let layer_output = self
            .output
            .forward(&intermediate_output, &attention_output)?;
        Ok(layer_output)
    }
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
struct BertEncoder {
    layers: Vec<BertLayer>,
    span: tracing::Span,
}

impl BertEncoder {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let layers = (0..config.num_hidden_layers)
            .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
            .collect::<Result<Vec<_>>>()?;
        let span = tracing::span!(tracing::Level::TRACE, "encoder");
        Ok(BertEncoder { layers, span })
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut hidden_states = hidden_states.clone();
        // Use a loop rather than a fold as it's easier to modify when adding debug/...
        for layer in self.layers.iter() {
            hidden_states = layer.forward(&hidden_states)?
        }
        Ok(hidden_states)
    }
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
pub struct BertModel {
    embeddings: BertEmbeddings,
    encoder: BertEncoder,
    pub device: Device,
    span: tracing::Span,
}

impl BertModel {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let (embeddings, encoder) = match (
            BertEmbeddings::load(vb.pp("embeddings"), config),
            BertEncoder::load(vb.pp("encoder"), config),
        ) {
            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
            (Err(err), _) | (_, Err(err)) => {
                if let Some(model_type) = &config.model_type {
                    if let (Ok(embeddings), Ok(encoder)) = (
                        BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
                        BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
                    ) {
                        (embeddings, encoder)
                    } else {
                        return Err(err);
                    }
                } else {
                    return Err(err);
                }
            }
        };
        Ok(Self {
            embeddings,
            encoder,
            device: vb.device().clone(),
            span: tracing::span!(tracing::Level::TRACE, "model"),
        })
    }

    pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
        let sequence_output = self.encoder.forward(&embedding_output)?;
        Ok(sequence_output)
    }
}
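
// Illustrative usage sketch (not part of the original file): load the weights from a
// local `model.safetensors` file and run a forward pass on a toy batch of token ids.
// The file name, the hard-coded token ids and the use of `Config::default()` are
// assumptions made purely for the example.
#[allow(dead_code)]
fn run_bert_example() -> Result<()> {
    let device = Device::Cpu;
    let config = Config::default();
    let vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DTYPE, &device)? };
    let model = BertModel::load(vb, &config)?;
    // One sequence of four token ids; token type ids are all zero for a single segment.
    let input_ids = Tensor::new(&[[101u32, 2023, 2003, 102]], &device)?;
    let token_type_ids = input_ids.zeros_like()?;
    let output = model.forward(&input_ids, &token_type_ids)?;
    // The result has shape (batch, seq_len, hidden_size), e.g. (1, 4, 768) here.
    println!("{:?}", output.dims());
    Ok(())
}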