1use super::nonlinear_gates::*;
6use super::AdderGateTrait;
7use crate::abstract_expr::AbstractExpression;
8use crate::components::binary_operations::binary_adder::BinaryAdder;
9use crate::layouter::builder::{Circuit, CircuitBuilder, InputLayerNodeRef, NodeRef};
10use itertools::Itertools;
11use remainder::mle::evals::MultilinearExtension;
12use shared_types::Field;
13use std::marker::PhantomData;
14use std::ops::Index;
15use std::sync::atomic::{AtomicU64, Ordering};
16
17pub const WORD_SIZE: usize = 32;
19
20pub type Sigma0<F> = Sigma<F, WORD_SIZE, 2, 13, 22>;
23
24pub type Sigma1<F> = Sigma<F, WORD_SIZE, 6, 11, 25>;
26
27pub type SmallSigma0<F> = SmallSigma<F, WORD_SIZE, 7, 18, 3>;
29
30pub type SmallSigma1<F> = SmallSigma<F, WORD_SIZE, 17, 19, 10>;
32
33pub type Sha256Adder<F> = CommittedCarryAdder<F>;
35
/// Returns a fresh name for a committed-carry input shred.
///
/// The name is the 16-digit lowercase hex rendering of a process-wide atomic
/// counter. The counter starts at `u64::MAX`, so the first name is
/// `ffffffffffffffff` and the counter then wraps to zero.
fn carry_name_counter() -> String {
    static NEXT_CARRY_ID: AtomicU64 = AtomicU64::new(u64::MAX);
    let id = NEXT_CARRY_ID.fetch_add(1, Ordering::SeqCst);
    format!("{id:016x}")
}
40
41fn add_get_carry_bits_lsb(x: u32, y: u32, mut c_in: u32) -> (u32, Vec<u32>) {
42 debug_assert!(c_in == 0 || c_in == 1);
43 let x_vec = bit_decompose_lsb_first(x);
44 let y_vec = bit_decompose_lsb_first(y);
45 let mut carry = Vec::<u32>::with_capacity(33);
46 let mut sum = 0x0u32;
47 for (i, (x, y)) in x_vec.into_iter().zip(y_vec.into_iter()).enumerate() {
48 let this_sum = x + y + c_in;
49 c_in = this_sum / 2;
50 carry.push(c_in);
51 sum = sum.wrapping_add((this_sum % 2) << i);
52 }
53 (sum, carry)
54}
55
56fn add_get_carry_bits_msb(x: u32, y: u32, c_in: u32) -> (u32, Vec<u32>) {
57 let (s, mut c) = add_get_carry_bits_lsb(x, y, c_in);
58 c.reverse();
59 (s, c)
60}
61
62#[derive(Debug, Clone)]
65pub struct CommittedCarryAdder<F: Field> {
66 carry_shred_name: String,
68 sum_node: NodeRef<F>,
70 _phantom: PhantomData<F>,
71}
72
73impl<F: Field> AdderGateTrait<F> for CommittedCarryAdder<F> {
74 type IntegralType = u32;
75
76 fn layout_adder_circuit(
77 circuit_builder: &mut CircuitBuilder<F>, x_node: &NodeRef<F>, y_node: &NodeRef<F>, carry_layer: Option<InputLayerNodeRef<F>>, ) -> Self {
82 assert!(
83 carry_layer.is_some(),
84 " Committed carry requires a carry layer"
85 );
86 CommittedCarryAdder::new(
87 circuit_builder,
88 carry_layer.as_ref().unwrap(),
89 x_node,
90 y_node,
91 )
92 }
93
94 fn get_output(&self) -> NodeRef<F> {
96 self.sum_node.clone()
97 }
98
99 fn perform_addition(
100 &self,
101 circuit: &mut Circuit<F>,
102 x: Self::IntegralType,
103 y: Self::IntegralType,
104 ) -> Self::IntegralType {
105 self.populate_carry(circuit, x, y)
106 }
107}
108
109impl<F: Field> CommittedCarryAdder<F> {
110 pub fn new(
113 ckt_builder: &mut CircuitBuilder<F>,
114 carry_layer: &InputLayerNodeRef<F>,
115 x_node: &NodeRef<F>,
116 y_node: &NodeRef<F>,
117 ) -> Self {
118 debug_assert_eq!(x_node.get_num_vars(), 5);
119 debug_assert_eq!(y_node.get_num_vars(), 5);
120
121 let carry_shred_name = carry_name_counter();
123 let carry_shred =
124 ckt_builder.add_input_shred(&carry_shred_name, x_node.get_num_vars(), carry_layer);
125
126 let b_sq = AbstractExpression::products(vec![carry_shred.id(), carry_shred.id()]);
128 let b = carry_shred.expr();
129
130 let binary_sector = ckt_builder.add_sector(
132 b - b_sq,
134 );
135
136 ckt_builder.set_output(&binary_sector);
137
138 let binary_adder = BinaryAdder::new(ckt_builder, x_node, y_node, &carry_shred);
139
140 Self {
141 carry_shred_name,
142 sum_node: binary_adder.get_output(),
143 _phantom: Default::default(),
144 }
145 }
146
147 pub fn populate_carry(&self, circuit: &mut Circuit<F>, x_val: u32, y_val: u32) -> u32 {
150 let (s, carries) = add_get_carry_bits_msb(x_val, y_val, 0);
151 debug_assert_eq!(s, x_val.wrapping_add(y_val));
152
153 let carry_mle = MultilinearExtension::new(
154 carries
155 .into_iter()
156 .map(u64::from)
157 .map(F::from)
158 .collect_vec(),
159 );
160 circuit.set_input(&self.carry_shred_name, carry_mle);
161 s
162 }
163}
164
/// The three adders that produce one expanded message-schedule word:
/// `(σ1 + W[t-7]) + (σ0 + W[t-16])`.
#[derive(Debug, Clone)]
struct MessageScheduleAdderTree<Adder> {
    /// Adds `σ1(W[t-2])` and `W[t-7]`.
    sum_a_leaf: Adder,
    /// Adds `σ0(W[t-15])` and `W[t-16]`.
    sum_b_leaf: Adder,
    /// Adds the outputs of the two leaf adders.
    sum_a_b: Adder,
}
171
172#[derive(Debug, Clone)]
173struct MessageScheduleState<F: Field, Adder> {
174 state_node: NodeRef<F>,
175 state_adders: Option<MessageScheduleAdderTree<Adder>>,
176}
177
178pub struct MessageSchedule<F: Field, Adder> {
182 msg_schedule: Vec<MessageScheduleState<F, Adder>>,
183 _phantom: PhantomData<F>,
184}
185
186impl<F, Adder> MessageSchedule<F, Adder>
187where
188 F: Field,
189 Adder: AdderGateTrait<F, IntegralType = u32> + Clone,
190{
191 pub fn new(
195 ckt_builder: &mut CircuitBuilder<F>,
196 carry_layer: Option<&InputLayerNodeRef<F>>,
197 msg_vars: &[NodeRef<F>],
198 ) -> Self {
199 debug_assert_eq!(msg_vars.len(), 16);
200
201 let mut state: Vec<MessageScheduleState<F, Adder>> = Vec::with_capacity(64);
202
203 (0..16).for_each(|i| {
204 debug_assert!(msg_vars[i].get_num_vars() == 5);
205 state.push(MessageScheduleState {
206 state_node: msg_vars[i].clone(),
207 state_adders: None,
208 })
209 });
210
211 for t in 16..64 {
212 assert!(state.len() >= t - 16);
213 let small_sigma_1_val = SmallSigma1::new(ckt_builder, &state[t - 2].state_node);
214 let w_first = state[t - 7].state_node.clone();
215
216 let small_sigma_0_val = SmallSigma0::new(ckt_builder, &state[t - 15].state_node);
217 let w_second = state[t - 16].state_node.clone();
218
219 let sum_a_leaf = Adder::layout_adder_circuit(
220 ckt_builder,
221 &small_sigma_1_val.get_output(),
222 &w_first,
223 carry_layer.cloned(),
224 );
225
226 let sum_b_leaf = Adder::layout_adder_circuit(
227 ckt_builder,
228 &small_sigma_0_val.get_output(),
229 &w_second,
230 carry_layer.cloned(),
231 );
232
233 let sum_a_b = Adder::layout_adder_circuit(
234 ckt_builder,
235 &sum_a_leaf.get_output(),
236 &sum_b_leaf.get_output(),
237 carry_layer.cloned(),
238 );
239
240 let current_state = MessageScheduleState {
241 state_node: sum_a_b.get_output(),
242 state_adders: Some(MessageScheduleAdderTree {
243 sum_a_leaf,
244 sum_b_leaf,
245 sum_a_b,
246 }),
247 };
248
249 state.push(current_state);
250 }
251
252 Self {
253 msg_schedule: state,
254 _phantom: Default::default(),
255 }
256 }
257
258 pub fn get_output_nodes(&self) -> Vec<NodeRef<F>> {
260 self.msg_schedule
261 .iter()
262 .map(|st| st.state_node.clone())
263 .collect()
264 }
265
266 pub fn populate_message_schedule(
269 &self,
270 circuit: &mut Circuit<F>,
271 input_data: &[u32],
272 ) -> Vec<u32> {
273 debug_assert_eq!(input_data.len(), 16);
274 let mut state: Vec<u32> = Vec::with_capacity(64);
275
276 (0..16).for_each(|i| state.push(input_data[i]));
277
278 for t in 16..64 {
279 assert!(state.len() >= t - 16);
280 let small_sigma_1_val = SmallSigma1::<F>::evaluate(state[t - 2]);
281 let w_first = state[t - 7];
282
283 let small_sigma_0_val = SmallSigma0::<F>::evaluate(state[t - 15]);
284 let w_second = state[t - 16];
285
286 let add_tree = self.msg_schedule[t].state_adders.clone().unwrap();
287
288 let sum_a_value =
289 add_tree
290 .sum_a_leaf
291 .perform_addition(circuit, small_sigma_1_val, w_first);
292
293 let sum_b_value =
294 add_tree
295 .sum_b_leaf
296 .perform_addition(circuit, small_sigma_0_val, w_second);
297
298 let sum_a_b_value =
299 add_tree
300 .sum_a_b
301 .perform_addition(circuit, sum_a_value, sum_b_value);
302
303 debug_assert_eq!(
304 sum_a_b_value,
305 small_sigma_1_val
306 .wrapping_add(w_first)
307 .wrapping_add(small_sigma_0_val)
308 .wrapping_add(w_second)
309 );
310 state.push(sum_a_b_value);
311 }
312
313 state
314 }
315
316 pub fn get_output_expr(&self) -> AbstractExpression<F> {
319 AbstractExpression::<F>::binary_tree_selector(
320 self.msg_schedule
321 .iter()
322 .map(|st| st.state_node.expr())
323 .collect(),
324 )
325 }
326}
327
328pub struct KeySchedule<F: Field> {
330 keys: Vec<ConstInputGate<F>>,
331}
332
333pub struct HConstants<F: Field> {
335 ivs: Vec<ConstInputGate<F>>,
336}
337
338impl<F: Field> HConstants<F> {
339 const H: [u32; 8] = [
340 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
341 0x5be0cd19,
342 ];
343
344 fn iv_name(index: usize) -> String {
345 format!("sha256-iv-{index:x}")
346 }
347
348 pub fn new(ckt_builder: &mut CircuitBuilder<F>) -> Self {
350 Self {
351 ivs: Self::H
352 .iter()
353 .enumerate()
354 .map(|(i, val)| ConstInputGate::new(ckt_builder, &Self::iv_name(i), *val))
355 .collect(),
356 }
357 }
358
359 pub fn populate_iv(&self, circuit: &mut Circuit<F>) {
361 for (ndx, const_iv) in self.ivs.iter().enumerate() {
362 circuit.set_input(&Self::iv_name(ndx), const_iv.input_mle().clone());
363 }
364 }
365
366 pub fn get_output_nodes(&self) -> Vec<NodeRef<F>> {
368 self.ivs.iter().map(|v| v.get_output()).collect()
369 }
370}
371
372impl<F: Field> KeySchedule<F> {
373 const ROUND_KEYS: [u32; 64] = [
374 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
375 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
376 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
377 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
378 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
379 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
380 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
381 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
382 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
383 0xc67178f2,
384 ];
385
386 fn key_name(index: usize) -> String {
387 format!("sha256-key-schedule-{index}")
388 }
389
390 pub fn new(ckt_builder: &mut CircuitBuilder<F>) -> Self {
392 Self {
393 keys: Self::ROUND_KEYS
394 .iter()
395 .enumerate()
396 .map(|(i, val)| ConstInputGate::new(ckt_builder, &Self::key_name(i), *val))
397 .collect(),
398 }
399 }
400
401 pub fn populate_key_schedule(&self, circuit: &mut Circuit<F>) {
403 for (i, key_const) in self.keys.iter().enumerate() {
404 circuit.set_input(&Self::key_name(i), key_const.input_mle().clone());
405 }
406 }
407}
408
409impl<F: Field> Index<usize> for KeySchedule<F> {
410 type Output = NodeRef<F>;
411
412 fn index(&self, index: usize) -> &Self::Output {
413 self.keys[index].get_output_ref()
414 }
415}
416
/// Adders laid out for one compression round.
struct CompressionFnRoundCarries<Adder> {
    /// The four adders that accumulate `T1`, in evaluation order.
    t1_carries: [Adder; 4],
    /// Adder producing `T2 = Σ0(a) + Maj(a, b, c)`.
    t2_carries: Adder,
    /// Adder producing the new `e = d + T1`.
    e_carries: Adder,
    /// Adder producing the new `a = T1 + T2`.
    a_carries: Adder,
}
423
424pub struct CompressionFn<F, Adder> {
426 output: Vec<Adder>,
427 round_carries: Vec<CompressionFnRoundCarries<Adder>>,
428 _phantom: PhantomData<F>,
429}
430
431impl<F, Adder> CompressionFn<F, Adder>
432where
433 F: Field,
434 Adder: AdderGateTrait<F, IntegralType = u32> + Clone,
435{
436 pub fn new(
441 ckt_builder: &mut CircuitBuilder<F>,
442 carry_layer: Option<&InputLayerNodeRef<F>>,
443 msg_schedule: &MessageSchedule<F, Adder>, input_schedule: &[NodeRef<F>], round_keys: &KeySchedule<F>, ) -> Self {
447 let msg_schedule = msg_schedule.get_output_nodes();
448 debug_assert_eq!(msg_schedule.len(), 64);
449 debug_assert_eq!(input_schedule.len(), 8);
450 (0..64).for_each(|i| debug_assert_eq!(msg_schedule[i].get_num_vars(), 5));
451 (0..8).for_each(|i| debug_assert_eq!(input_schedule[i].get_num_vars(), 5));
452
453 let mut round_carries: Vec<CompressionFnRoundCarries<Adder>> = Vec::with_capacity(64);
454
455 let mut a = input_schedule[0].clone();
456 let mut b = input_schedule[1].clone();
457 let mut c = input_schedule[2].clone();
458 let mut d = input_schedule[3].clone();
459 let mut e = input_schedule[4].clone();
460 let mut f = input_schedule[5].clone();
461 let mut g = input_schedule[6].clone();
462 let mut h = input_schedule[7].clone();
463
464 for t in 0..64 {
465 let w_t = &msg_schedule[t];
466 let k_t = &round_keys[t];
467 let t1 = Self::compute_t1(ckt_builder, carry_layer, &e, &f, &g, &h, w_t, k_t);
468 let t2 = Self::compute_t2(ckt_builder, carry_layer, &a, &b, &c);
469 h = ckt_builder.add_sector(g.expr());
470 g = ckt_builder.add_sector(f.expr());
471 f = ckt_builder.add_sector(e.expr());
472 let e_sum = Adder::layout_adder_circuit(
473 ckt_builder,
474 &d,
475 &t1.last().unwrap().get_output(),
476 carry_layer.cloned(),
477 );
478 e = e_sum.get_output();
479 d = ckt_builder.add_sector(c.expr());
480 c = ckt_builder.add_sector(b.expr());
481 b = ckt_builder.add_sector(a.expr());
482 let a_sum = Adder::layout_adder_circuit(
483 ckt_builder,
484 &t1.last().unwrap().get_output(),
485 &t2.get_output(),
486 carry_layer.cloned(),
487 );
488 a = a_sum.get_output();
489 round_carries.push(CompressionFnRoundCarries {
490 t1_carries: t1,
491 t2_carries: t2,
492 e_carries: e_sum,
493 a_carries: a_sum,
494 });
495 }
496
497 let intermediates = [a, b, c, d, e, f, g, h];
498
499 let output = input_schedule
500 .iter()
501 .zip(intermediates.iter())
502 .map(|(h, x)| Adder::layout_adder_circuit(ckt_builder, h, x, carry_layer.cloned()))
503 .collect();
504
505 Self {
506 output,
507 round_carries,
508 _phantom: Default::default(),
509 }
510 }
511
512 #[allow(clippy::too_many_arguments)]
513 fn compute_t1(
514 ckt_builder: &mut CircuitBuilder<F>,
515 carry_layer: Option<&InputLayerNodeRef<F>>,
516 e: &NodeRef<F>,
517 f: &NodeRef<F>,
518 g: &NodeRef<F>,
519 h: &NodeRef<F>,
520 w_t: &NodeRef<F>,
521 k_t: &NodeRef<F>,
522 ) -> [Adder; 4] {
523 debug_assert!(e.get_num_vars() == 5);
524 debug_assert!(f.get_num_vars() == 5);
525 debug_assert!(g.get_num_vars() == 5);
526 debug_assert!(h.get_num_vars() == 5);
527 debug_assert!(w_t.get_num_vars() == 5);
528 debug_assert!(k_t.get_num_vars() == 5);
529
530 let t1_sigma_1 = Sigma1::new(ckt_builder, e);
531 let t1_ch = ChGate::new(ckt_builder, e, f, g);
532
533 let sum1 = Adder::layout_adder_circuit(
535 ckt_builder,
536 h,
537 &t1_sigma_1.get_output(),
538 carry_layer.cloned(),
539 );
540
541 let sum2 = Adder::layout_adder_circuit(
543 ckt_builder,
544 &t1_ch.get_output(),
545 k_t,
546 carry_layer.cloned(),
547 );
548
549 let sum3 = Adder::layout_adder_circuit(
551 ckt_builder,
552 &sum1.get_output(),
553 &sum2.get_output(),
554 carry_layer.cloned(),
555 );
556
557 let sum4 =
559 Adder::layout_adder_circuit(ckt_builder, &sum3.get_output(), w_t, carry_layer.cloned());
560
561 [sum1, sum2, sum3, sum4]
562 }
563
564 fn compute_t2(
565 ckt_builder: &mut CircuitBuilder<F>,
566 carry_layer: Option<&InputLayerNodeRef<F>>,
567 a: &NodeRef<F>,
568 b: &NodeRef<F>,
569 c: &NodeRef<F>,
570 ) -> Adder {
571 let s1 = Sigma0::new(ckt_builder, a);
572 let m1 = MajGate::new(ckt_builder, a, b, c);
573 Adder::layout_adder_circuit(
574 ckt_builder,
575 &s1.get_output(),
576 &m1.get_output(),
577 carry_layer.cloned(),
578 )
579 }
580
581 pub fn populate_compression_fn(
597 &self,
598 circuit: &mut Circuit<F>,
599 message_words: Vec<u32>,
600 input_words: Vec<u32>,
601 ) -> Vec<u32> {
602 let mut a = input_words[0];
603 let mut b = input_words[1];
604 let mut c = input_words[2];
605 let mut d = input_words[3];
606 let mut e = input_words[4];
607 let mut f = input_words[5];
608 let mut g = input_words[6];
609 let mut h = input_words[7];
610
611 for (t, w_t) in message_words.iter().enumerate().take(64) {
624 let k_t = KeySchedule::<F>::ROUND_KEYS[t];
625 let [sum1, sum2, sum3, sum4] = &self.round_carries[t].t1_carries;
626 let t1 = Self::populate_t1(circuit, sum1, sum2, sum3, sum4, e, f, g, h, *w_t, k_t);
627 let t2 = Self::populate_t2(circuit, &self.round_carries[t].t2_carries, a, b, c);
628 h = g;
629 g = f;
630 f = e;
631 e = self.round_carries[t]
632 .e_carries
633 .perform_addition(circuit, d, t1);
634 d = c;
635 c = b;
636 b = a;
637 a = self.round_carries[t]
638 .a_carries
639 .perform_addition(circuit, t1, t2);
640
641 }
647
648 let intermediates = [a, b, c, d, e, f, g, h];
649
650 input_words
651 .iter()
652 .zip(intermediates.iter())
653 .zip(&self.output)
654 .map(|((h, hprime), gate)| gate.perform_addition(circuit, *h, *hprime))
655 .collect()
656 }
657
658 #[allow(clippy::too_many_arguments)]
659 fn populate_t1(
660 circuit: &mut Circuit<F>,
661 sum1: &Adder,
662 sum2: &Adder,
663 sum3: &Adder,
664 sum4: &Adder,
665 e: u32,
666 f: u32,
667 g: u32,
668 h: u32,
669 w_t: u32,
670 k_t: u32,
671 ) -> u32 {
672 let t1_sigma_1 = Sigma1::<F>::evaluate(e);
673 let t1_ch = ChGate::<F>::evaluate(e, f, g);
674
675 let sum1 = sum1.perform_addition(circuit, h, t1_sigma_1);
677
678 let sum2 = sum2.perform_addition(circuit, t1_ch, k_t);
680
681 let sum3 = sum3.perform_addition(circuit, sum1, sum2);
683
684 sum4.perform_addition(circuit, sum3, w_t)
686 }
687
688 fn populate_t2(circuit: &mut Circuit<F>, sum1: &Adder, a: u32, b: u32, c: u32) -> u32 {
689 let s1 = Sigma0::<F>::evaluate(a);
690 let m1 = MajGate::<F>::evaluate(a, b, c);
691 sum1.perform_addition(circuit, s1, m1)
692 }
693
694 pub fn get_output_nodes(&self) -> Vec<NodeRef<F>> {
697 self.output.iter().map(|n| n.get_output()).collect()
698 }
699}
700
/// Applies SHA-256 message padding (FIPS 180-4, section 5.1.1) and packs the
/// result into big-endian 32-bit words.
///
/// The padded message consists of:
/// 1. the original message;
/// 2. a single `1` bit (the `0x80` byte);
/// 3. the minimum number of zero bits that makes the bit length congruent to
///    448 modulo 512;
/// 4. the original length in bits as a 64-bit big-endian integer.
///
/// The result always holds a multiple of 16 words (16 words per 512-bit block).
fn sha256_padded_input(mut input_data: Vec<u8>) -> Vec<u32> {
    // The length field is always 64 bits wide (taken modulo 2^64), independent
    // of the platform's `usize` width. `usize -> u64` is lossless on all
    // supported targets.
    let input_len_bits = (input_data.len() as u64).wrapping_mul(8);

    // Zero-bit count k >= 0 with len + 1 + k ≡ 448 (mod 512). Computing it
    // modulo 512 avoids underflow when the final partial block already holds
    // more than 447 bits (56..=63 trailing bytes); those messages spill into
    // an extra block.
    let tail_bits = (input_len_bits % 512) as usize;
    let zero_bits = (448 + 512 - 1 - tail_bits) % 512;
    // `len` is a whole number of bytes, so zero_bits ≡ 7 (mod 8). The 0x80
    // byte supplies the `1` bit plus 7 zero bits; the rest are whole bytes.
    let zero_bytes = zero_bits / 8;

    input_data.push(0x80);
    input_data.resize(input_data.len() + zero_bytes, 0);
    input_data.extend_from_slice(&input_len_bits.to_be_bytes());
    assert_eq!(input_data.len() % 64, 0);

    input_data
        .chunks_exact(4)
        .map(|chunk| u32::from_be_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect()
}
716
717struct Sha256State<F: Field, Adder> {
718 message_schedule: MessageSchedule<F, Adder>, compression_fn: CompressionFn<F, Adder>,
720 input_chunks: Vec<u32>, }
722
723pub struct Sha256<F: Field, Adder> {
725 key_schedule: KeySchedule<F>,
726 init_iv: HConstants<F>,
727 round_states: Vec<Sha256State<F, Adder>>,
728}
729
730impl<F, Adder> Sha256<F, Adder>
731where
732 F: Field,
733 Adder: AdderGateTrait<F, IntegralType = u32> + Clone,
734{
735 pub fn new(
737 ckt_builder: &mut CircuitBuilder<F>,
738 data_input_layer: &InputLayerNodeRef<F>,
739 carry_layer: Option<&InputLayerNodeRef<F>>,
740 input_data: Vec<u8>,
741 ) -> Self {
742 let key_schedule = KeySchedule::<F>::new(ckt_builder); let init_iv = HConstants::<F>::new(ckt_builder); let input_data = sha256_padded_input(input_data);
745 let num_vars = input_data.len().ilog2() as usize + 5; let all_input = ckt_builder.add_input_shred("SHA256_input", num_vars, data_input_layer);
747
748 let b = &all_input;
749 let b_sq = AbstractExpression::products(vec![b.id(), b.id()]);
750 let b = b.expr();
751
752 let binary_sector = ckt_builder.add_sector(
754 b - b_sq,
756 );
757
758 ckt_builder.set_output(&binary_sector);
760
761 let mut input_schedule = init_iv.get_output_nodes();
762
763 let input_word_splits =
764 ckt_builder.add_split_node(&all_input, input_data.len().ilog2() as usize);
765
766 debug_assert_eq!(input_word_splits.len(), input_data.len());
767
768 let round_states = input_word_splits
769 .as_slice()
770 .chunks_exact(16)
771 .zip(input_data.as_slice().chunks_exact(16))
772 .map(|(msg_vars, data)| {
773 let message_schedule = MessageSchedule::new(ckt_builder, carry_layer, msg_vars);
774 let ckt = CompressionFn::new(
775 ckt_builder,
776 carry_layer,
777 &message_schedule,
778 &input_schedule,
779 &key_schedule,
780 );
781 input_schedule = ckt.get_output_nodes();
782
783 Sha256State {
784 message_schedule,
785 input_chunks: data.to_vec(),
786 compression_fn: ckt,
787 }
788 })
789 .collect_vec();
790 Self {
791 key_schedule,
792 round_states,
793 init_iv,
794 }
795 }
796
797 pub fn padded_data_chunks(&self) -> Vec<u32> {
799 self.round_states
800 .iter()
801 .flat_map(|st| st.input_chunks.clone())
802 .collect()
803 }
804
805 pub fn get_output_node(&self) -> Vec<NodeRef<F>> {
807 self.round_states
808 .last()
809 .map(|v| v.compression_fn.get_output_nodes())
810 .unwrap()
811 }
812
813 fn populate_state(
815 &self,
816 circuit: &mut Circuit<F>,
817 state: &Sha256State<F, Adder>,
818 input_words: Vec<u32>,
819 ) -> Vec<u32> {
820 let message_words = state
821 .message_schedule
822 .populate_message_schedule(circuit, &state.input_chunks);
823 state
824 .compression_fn
825 .populate_compression_fn(circuit, message_words, input_words)
826 }
827
828 pub fn populate_circuit(&self, circuit: &mut Circuit<F>) -> Vec<u32> {
831 let all_bits = self
832 .round_states
833 .iter()
834 .flat_map(|st| st.input_chunks.iter().map(|v| bit_decompose_msb_first(*v)))
835 .flatten()
836 .map(u64::from)
837 .map(F::from)
838 .collect_vec();
839
840 self.key_schedule.populate_key_schedule(circuit);
841 self.init_iv.populate_iv(circuit);
842
843 let final_result = self
844 .round_states
845 .iter()
846 .fold(HConstants::<F>::H.to_vec(), |hash_val, st| {
847 self.populate_state(circuit, st, hash_val)
848 });
849
850 let input_mle = MultilinearExtension::new(all_bits);
851 circuit.set_input("SHA256_input", input_mle);
852 final_result
853 }
854}
855
#[cfg(test)]
mod tests {
    use super::add_get_carry_bits_lsb;

    /// Checks sums and per-bit carries of the LSB-first ripple adder on
    /// all-carry and single-carry cases, with and without carry-in.
    #[test]
    fn test_bit_decomp_carry() {
        let x = 0xffff_ffffu32;
        let y = 0x1u32;
        let z = 0x8000_0000u32;

        // Every position carries.
        let all_ones: Vec<u32> = vec![1; 32];
        // Only the most significant position carries.
        let mut top_only: Vec<u32> = vec![0; 32];
        top_only[31] = 1;

        let cases = [
            (x, y, 0u32, &all_ones),
            (x, y, 1, &all_ones),
            (x, z, 0, &top_only),
            (x, z, 1, &all_ones),
        ];

        for (lhs, rhs, c_in, expected) in cases {
            let (sum, carries) = add_get_carry_bits_lsb(lhs, rhs, c_in);
            assert_eq!(sum, lhs.wrapping_add(rhs).wrapping_add(c_in));
            assert_eq!(&carries, expected);
        }
    }
}