frontend/components/sha2_gkr/
sha256_bit_decomp.rs

1//!
2//! Implementation of SHA-256 circuit using bitwise decomposition
3//!
4
5use super::nonlinear_gates::*;
6use super::AdderGateTrait;
7use crate::abstract_expr::AbstractExpression;
8use crate::components::binary_operations::binary_adder::BinaryAdder;
9use crate::layouter::builder::{Circuit, CircuitBuilder, InputLayerNodeRef, NodeRef};
10use itertools::Itertools;
11use remainder::mle::evals::MultilinearExtension;
12use shared_types::Field;
13use std::marker::PhantomData;
14use std::ops::Index;
15use std::sync::atomic::{AtomicU64, Ordering};
16
/// Word size, in bits, used by SHA-256 (every word is 32 bits wide).
pub const WORD_SIZE: usize = 32;

/// See <https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf> for details about constants
/// Sigma0 of SHA-256. The numeric parameters 2, 13, 22 match the rotation
/// amounts of the FIPS 180-4 "big sigma 0" function (presumably how `Sigma`
/// interprets them; confirm against `nonlinear_gates`).
pub type Sigma0<F> = Sigma<F, WORD_SIZE, 2, 13, 22>;

/// Sigma1 of SHA-256. The numeric parameters 6, 11, 25 match the rotation
/// amounts of the FIPS 180-4 "big sigma 1" function.
pub type Sigma1<F> = Sigma<F, WORD_SIZE, 6, 11, 25>;

/// Little Sigma0. The parameters 7, 18, 3 match the FIPS 180-4 "small sigma 0"
/// amounts (two rotations and one shift in the standard).
pub type SmallSigma0<F> = SmallSigma<F, WORD_SIZE, 7, 18, 3>;

/// Little Sigma1. The parameters 17, 19, 10 match the FIPS 180-4 "small sigma 1"
/// amounts (two rotations and one shift in the standard).
pub type SmallSigma1<F> = SmallSigma<F, WORD_SIZE, 17, 19, 10>;

/// Specific adder for SHA256
pub type Sha256Adder<F> = CommittedCarryAdder<F>;
35
/// Returns a fresh 16-hex-digit label for a carry input shred.
///
/// Labels come from a process-wide atomic counter, so every call yields a
/// different string until the 64-bit counter wraps. The counter starts at
/// `u64::MAX`, hence the first label is all `f`s and the second is all zeros.
fn carry_name_counter() -> String {
    static NEXT_CARRY_ID: AtomicU64 = AtomicU64::new(u64::MAX);
    let id = NEXT_CARRY_ID.fetch_add(1, Ordering::SeqCst);
    format!("{id:016x}")
}
40
/// Adds `x + y + c_in` bit by bit and records the carry out of every bit.
///
/// Returns `(sum, carries)` where `sum` is the 32-bit wrapping sum and
/// `carries[i]` is the carry produced by bit position `i` (index 0 is the
/// least significant bit), so `carries` always has 32 entries.
///
/// `c_in` must be 0 or 1; this is only checked in debug builds.
fn add_get_carry_bits_lsb(x: u32, y: u32, mut c_in: u32) -> (u32, Vec<u32>) {
    debug_assert!(c_in == 0 || c_in == 1);
    // Exactly one carry per bit position, so reserve exactly 32 slots.
    let mut carry = Vec::<u32>::with_capacity(u32::BITS as usize);
    let mut sum = 0u32;
    for i in 0..u32::BITS {
        // Read bit `i` of each operand directly instead of materialising
        // two temporary bit vectors.
        let bit_total = ((x >> i) & 1) + ((y >> i) & 1) + c_in;
        c_in = bit_total >> 1;
        carry.push(c_in);
        // Each iteration owns a distinct bit position, so OR-ing is exact.
        sum |= (bit_total & 1) << i;
    }
    (sum, carry)
}
55
56fn add_get_carry_bits_msb(x: u32, y: u32, c_in: u32) -> (u32, Vec<u32>) {
57    let (s, mut c) = add_get_carry_bits_lsb(x, y, c_in);
58    c.reverse();
59    (s, c)
60}
61
/// An adder that just checks the carry bits instead of explicitly
/// computing it through Ripple Carry Adder.
///
/// The carry bits are supplied as a separate input shred (see
/// `populate_carry`); the circuit constrains them to be binary and feeds
/// them to a `BinaryAdder`.
#[derive(Debug, Clone)]
pub struct CommittedCarryAdder<F: Field> {
    // Automatically generated label of the carry input shred; the same
    // label is used later to attach the carry values to the circuit.
    carry_shred_name: String,
    /// Node representing the sum
    sum_node: NodeRef<F>,
    // Marker for the field type parameter.
    _phantom: PhantomData<F>,
}
72
73impl<F: Field> AdderGateTrait<F> for CommittedCarryAdder<F> {
74    type IntegralType = u32;
75
76    fn layout_adder_circuit(
77        circuit_builder: &mut CircuitBuilder<F>,   // Circuit builder
78        x_node: &NodeRef<F>,                       // reference to x in x + y
79        y_node: &NodeRef<F>,                       // reference to y in x + y
80        carry_layer: Option<InputLayerNodeRef<F>>, // Carry Layer information
81    ) -> Self {
82        assert!(
83            carry_layer.is_some(),
84            " Committed carry requires a carry layer"
85        );
86        CommittedCarryAdder::new(
87            circuit_builder,
88            carry_layer.as_ref().unwrap(),
89            x_node,
90            y_node,
91        )
92    }
93
94    /// Node representing the output value
95    fn get_output(&self) -> NodeRef<F> {
96        self.sum_node.clone()
97    }
98
99    fn perform_addition(
100        &self,
101        circuit: &mut Circuit<F>,
102        x: Self::IntegralType,
103        y: Self::IntegralType,
104    ) -> Self::IntegralType {
105        self.populate_carry(circuit, x, y)
106    }
107}
108
109impl<F: Field> CommittedCarryAdder<F> {
110    /// Adder that adds a carry shred for carry bits and only checks
111    /// that the sums are correct instead of computing it.
112    pub fn new(
113        ckt_builder: &mut CircuitBuilder<F>,
114        carry_layer: &InputLayerNodeRef<F>,
115        x_node: &NodeRef<F>,
116        y_node: &NodeRef<F>,
117    ) -> Self {
118        debug_assert_eq!(x_node.get_num_vars(), 5);
119        debug_assert_eq!(y_node.get_num_vars(), 5);
120
121        // Probability of collision is 2^{-128}
122        let carry_shred_name = carry_name_counter();
123        let carry_shred =
124            ckt_builder.add_input_shred(&carry_shred_name, x_node.get_num_vars(), carry_layer);
125
126        // Check that all input bits are binary.
127        let b_sq = AbstractExpression::products(vec![carry_shred.id(), carry_shred.id()]);
128        let b = carry_shred.expr();
129
130        // Check that all input bits are binary.
131        let binary_sector = ckt_builder.add_sector(
132            // b * (1 - b) = b - b^2
133            b - b_sq,
134        );
135
136        ckt_builder.set_output(&binary_sector);
137
138        let binary_adder = BinaryAdder::new(ckt_builder, x_node, y_node, &carry_shred);
139
140        Self {
141            carry_shred_name,
142            sum_node: binary_adder.get_output(),
143            _phantom: Default::default(),
144        }
145    }
146
147    /// Given an instantiated circuit, adds gate to the circuit as input
148    /// with correct input label name.
149    pub fn populate_carry(&self, circuit: &mut Circuit<F>, x_val: u32, y_val: u32) -> u32 {
150        let (s, carries) = add_get_carry_bits_msb(x_val, y_val, 0);
151        debug_assert_eq!(s, x_val.wrapping_add(y_val));
152
153        let carry_mle = MultilinearExtension::new(
154            carries
155                .into_iter()
156                .map(u64::from)
157                .map(F::from)
158                .collect_vec(),
159        );
160        circuit.set_input(&self.carry_shred_name, carry_mle);
161        s
162    }
163}
164
/// The three adders used to compute one message schedule word for a
/// round `t >= 16`: two leaf additions and the addition combining them.
#[derive(Debug, Clone)]
struct MessageScheduleAdderTree<Adder> {
    sum_a_leaf: Adder, // A = SmallSigma1(state[t-2]) + state[t - 7]
    sum_b_leaf: Adder, // B = SmallSigma0(state[t-15]) + state[t - 16]
    sum_a_b: Adder,    // C = A + B
}
171
/// One word of the message schedule.
#[derive(Debug, Clone)]
struct MessageScheduleState<F: Field, Adder> {
    // Node carrying this schedule word.
    state_node: NodeRef<F>,
    // Adders that produced the word; `None` for the first 16 words, which
    // are taken directly from the message.
    state_adders: Option<MessageScheduleAdderTree<Adder>>,
}
177
/// Represents the 64 words of the message schedule. The first 16 words are
/// the message words themselves (no adders attached), and the remaining 48
/// are computed from earlier words as per the spec.
pub struct MessageSchedule<F: Field, Adder> {
    // One entry per schedule word, in order `t = 0..64`.
    msg_schedule: Vec<MessageScheduleState<F, Adder>>,
    _phantom: PhantomData<F>,
}
185
186impl<F, Adder> MessageSchedule<F, Adder>
187where
188    F: Field,
189    Adder: AdderGateTrait<F, IntegralType = u32> + Clone,
190{
191    /// Given the 16, 32-bit inputs in MBS format, computes the 64
192    /// rounds of message schedule corresponding to the input. The
193    /// `msg_vars` must be the 16 32-bit words decomposed in MBS format.
194    pub fn new(
195        ckt_builder: &mut CircuitBuilder<F>,
196        carry_layer: Option<&InputLayerNodeRef<F>>,
197        msg_vars: &[NodeRef<F>],
198    ) -> Self {
199        debug_assert_eq!(msg_vars.len(), 16);
200
201        let mut state: Vec<MessageScheduleState<F, Adder>> = Vec::with_capacity(64);
202
203        (0..16).for_each(|i| {
204            debug_assert!(msg_vars[i].get_num_vars() == 5);
205            state.push(MessageScheduleState {
206                state_node: msg_vars[i].clone(),
207                state_adders: None,
208            })
209        });
210
211        for t in 16..64 {
212            assert!(state.len() >= t - 16);
213            let small_sigma_1_val = SmallSigma1::new(ckt_builder, &state[t - 2].state_node);
214            let w_first = state[t - 7].state_node.clone();
215
216            let small_sigma_0_val = SmallSigma0::new(ckt_builder, &state[t - 15].state_node);
217            let w_second = state[t - 16].state_node.clone();
218
219            let sum_a_leaf = Adder::layout_adder_circuit(
220                ckt_builder,
221                &small_sigma_1_val.get_output(),
222                &w_first,
223                carry_layer.cloned(),
224            );
225
226            let sum_b_leaf = Adder::layout_adder_circuit(
227                ckt_builder,
228                &small_sigma_0_val.get_output(),
229                &w_second,
230                carry_layer.cloned(),
231            );
232
233            let sum_a_b = Adder::layout_adder_circuit(
234                ckt_builder,
235                &sum_a_leaf.get_output(),
236                &sum_b_leaf.get_output(),
237                carry_layer.cloned(),
238            );
239
240            let current_state = MessageScheduleState {
241                state_node: sum_a_b.get_output(),
242                state_adders: Some(MessageScheduleAdderTree {
243                    sum_a_leaf,
244                    sum_b_leaf,
245                    sum_a_b,
246                }),
247            };
248
249            state.push(current_state);
250        }
251
252        Self {
253            msg_schedule: state,
254            _phantom: Default::default(),
255        }
256    }
257
258    /// Returns the list of 64 nodes corresponding to SHA256 message schedule
259    pub fn get_output_nodes(&self) -> Vec<NodeRef<F>> {
260        self.msg_schedule
261            .iter()
262            .map(|st| st.state_node.clone())
263            .collect()
264    }
265
266    /// Attaches a message schedule to the SHA Circuit. Note that this
267    /// needs to match the way adder is implemented
268    pub fn populate_message_schedule(
269        &self,
270        circuit: &mut Circuit<F>,
271        input_data: &[u32],
272    ) -> Vec<u32> {
273        debug_assert_eq!(input_data.len(), 16);
274        let mut state: Vec<u32> = Vec::with_capacity(64);
275
276        (0..16).for_each(|i| state.push(input_data[i]));
277
278        for t in 16..64 {
279            assert!(state.len() >= t - 16);
280            let small_sigma_1_val = SmallSigma1::<F>::evaluate(state[t - 2]);
281            let w_first = state[t - 7];
282
283            let small_sigma_0_val = SmallSigma0::<F>::evaluate(state[t - 15]);
284            let w_second = state[t - 16];
285
286            let add_tree = self.msg_schedule[t].state_adders.clone().unwrap();
287
288            let sum_a_value =
289                add_tree
290                    .sum_a_leaf
291                    .perform_addition(circuit, small_sigma_1_val, w_first);
292
293            let sum_b_value =
294                add_tree
295                    .sum_b_leaf
296                    .perform_addition(circuit, small_sigma_0_val, w_second);
297
298            let sum_a_b_value =
299                add_tree
300                    .sum_a_b
301                    .perform_addition(circuit, sum_a_value, sum_b_value);
302
303            debug_assert_eq!(
304                sum_a_b_value,
305                small_sigma_1_val
306                    .wrapping_add(w_first)
307                    .wrapping_add(small_sigma_0_val)
308                    .wrapping_add(w_second)
309            );
310            state.push(sum_a_b_value);
311        }
312
313        state
314    }
315
316    /// Creates the output expression that can be tested with other
317    /// output expressions
318    pub fn get_output_expr(&self) -> AbstractExpression<F> {
319        AbstractExpression::<F>::binary_tree_selector(
320            self.msg_schedule
321                .iter()
322                .map(|st| st.state_node.expr())
323                .collect(),
324        )
325    }
326}
327
/// The constant key schedule: one constant input gate per SHA-256 round
/// constant (64 in total, see `ROUND_KEYS`).
pub struct KeySchedule<F: Field> {
    // Gates in round order; indexable through the `Index<usize>` impl.
    keys: Vec<ConstInputGate<F>>,
}
332
/// The initial hash value (IV) of SHA-256: one constant input gate per
/// 32-bit word of `H` (8 in total).
pub struct HConstants<F: Field> {
    // Gates in the same order as the `H` constant array.
    ivs: Vec<ConstInputGate<F>>,
}
337
338impl<F: Field> HConstants<F> {
339    const H: [u32; 8] = [
340        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
341        0x5be0cd19,
342    ];
343
344    fn iv_name(index: usize) -> String {
345        format!("sha256-iv-{index:x}")
346    }
347
348    /// Create the input gate to SHA entire message
349    pub fn new(ckt_builder: &mut CircuitBuilder<F>) -> Self {
350        Self {
351            ivs: Self::H
352                .iter()
353                .enumerate()
354                .map(|(i, val)| ConstInputGate::new(ckt_builder, &Self::iv_name(i), *val))
355                .collect(),
356        }
357    }
358
359    /// Populate the circuit with the IV values
360    pub fn populate_iv(&self, circuit: &mut Circuit<F>) {
361        for (ndx, const_iv) in self.ivs.iter().enumerate() {
362            circuit.set_input(&Self::iv_name(ndx), const_iv.input_mle().clone());
363        }
364    }
365
366    /// Returns the list of H constants as wires in the circuit
367    pub fn get_output_nodes(&self) -> Vec<NodeRef<F>> {
368        self.ivs.iter().map(|v| v.get_output()).collect()
369    }
370}
371
372impl<F: Field> KeySchedule<F> {
373    const ROUND_KEYS: [u32; 64] = [
374        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
375        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
376        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
377        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
378        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
379        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
380        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
381        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
382        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
383        0xc67178f2,
384    ];
385
386    fn key_name(index: usize) -> String {
387        format!("sha256-key-schedule-{index}")
388    }
389
390    /// Create the SHA-256 key schedule
391    pub fn new(ckt_builder: &mut CircuitBuilder<F>) -> Self {
392        Self {
393            keys: Self::ROUND_KEYS
394                .iter()
395                .enumerate()
396                .map(|(i, val)| ConstInputGate::new(ckt_builder, &Self::key_name(i), *val))
397                .collect(),
398        }
399    }
400
401    /// the Key schedule to circuit
402    pub fn populate_key_schedule(&self, circuit: &mut Circuit<F>) {
403        for (i, key_const) in self.keys.iter().enumerate() {
404            circuit.set_input(&Self::key_name(i), key_const.input_mle().clone());
405        }
406    }
407}
408
409impl<F: Field> Index<usize> for KeySchedule<F> {
410    type Output = NodeRef<F>;
411
412    fn index(&self, index: usize) -> &Self::Output {
413        self.keys[index].get_output_ref()
414    }
415}
416
/// The adders used by one round of the compression function, kept so that
/// their carries can be populated in the same order they were laid out.
struct CompressionFnRoundCarries<Adder> {
    t1_carries: [Adder; 4], // Four '+' operations: T1 := h + Sigma1(e) + ch(e,f,g) + K_t + W_t
    t2_carries: Adder,      // One '+' operation: T2 := Sigma0(a) + Maj(a,b,c)
    e_carries: Adder,       // One '+' operation: e := d + T1
    a_carries: Adder,       // One '+' operation: a := T1 + T2
}
423
/// The compression function for one message block: all 64 rounds plus the
/// final addition of the working variables to the incoming hash value.
pub struct CompressionFn<F, Adder> {
    // The eight adders producing the outgoing hash words
    // (incoming hash word + final working variable).
    output: Vec<Adder>,
    // Adders of each round, in round order.
    round_carries: Vec<CompressionFnRoundCarries<Adder>>,
    _phantom: PhantomData<F>,
}
430
431impl<F, Adder> CompressionFn<F, Adder>
432where
433    F: Field,
434    Adder: AdderGateTrait<F, IntegralType = u32> + Clone,
435{
436    /// A Single Round of SHA-256 compression function. The
437    /// `msg_schedule` is the 64 rounds of expanded message schedule.
438    /// The `input_schedule` is the 256-bits of Input values
439    /// `round_keys` are
440    pub fn new(
441        ckt_builder: &mut CircuitBuilder<F>,
442        carry_layer: Option<&InputLayerNodeRef<F>>,
443        msg_schedule: &MessageSchedule<F, Adder>, // Expanded message schedule
444        input_schedule: &[NodeRef<F>],            // IV for fist message
445        round_keys: &KeySchedule<F>,              // Key Schedule
446    ) -> Self {
447        let msg_schedule = msg_schedule.get_output_nodes();
448        debug_assert_eq!(msg_schedule.len(), 64);
449        debug_assert_eq!(input_schedule.len(), 8);
450        (0..64).for_each(|i| debug_assert_eq!(msg_schedule[i].get_num_vars(), 5));
451        (0..8).for_each(|i| debug_assert_eq!(input_schedule[i].get_num_vars(), 5));
452
453        let mut round_carries: Vec<CompressionFnRoundCarries<Adder>> = Vec::with_capacity(64);
454
455        let mut a = input_schedule[0].clone();
456        let mut b = input_schedule[1].clone();
457        let mut c = input_schedule[2].clone();
458        let mut d = input_schedule[3].clone();
459        let mut e = input_schedule[4].clone();
460        let mut f = input_schedule[5].clone();
461        let mut g = input_schedule[6].clone();
462        let mut h = input_schedule[7].clone();
463
464        for t in 0..64 {
465            let w_t = &msg_schedule[t];
466            let k_t = &round_keys[t];
467            let t1 = Self::compute_t1(ckt_builder, carry_layer, &e, &f, &g, &h, w_t, k_t);
468            let t2 = Self::compute_t2(ckt_builder, carry_layer, &a, &b, &c);
469            h = ckt_builder.add_sector(g.expr());
470            g = ckt_builder.add_sector(f.expr());
471            f = ckt_builder.add_sector(e.expr());
472            let e_sum = Adder::layout_adder_circuit(
473                ckt_builder,
474                &d,
475                &t1.last().unwrap().get_output(),
476                carry_layer.cloned(),
477            );
478            e = e_sum.get_output();
479            d = ckt_builder.add_sector(c.expr());
480            c = ckt_builder.add_sector(b.expr());
481            b = ckt_builder.add_sector(a.expr());
482            let a_sum = Adder::layout_adder_circuit(
483                ckt_builder,
484                &t1.last().unwrap().get_output(),
485                &t2.get_output(),
486                carry_layer.cloned(),
487            );
488            a = a_sum.get_output();
489            round_carries.push(CompressionFnRoundCarries {
490                t1_carries: t1,
491                t2_carries: t2,
492                e_carries: e_sum,
493                a_carries: a_sum,
494            });
495        }
496
497        let intermediates = [a, b, c, d, e, f, g, h];
498
499        let output = input_schedule
500            .iter()
501            .zip(intermediates.iter())
502            .map(|(h, x)| Adder::layout_adder_circuit(ckt_builder, h, x, carry_layer.cloned()))
503            .collect();
504
505        Self {
506            output,
507            round_carries,
508            _phantom: Default::default(),
509        }
510    }
511
512    #[allow(clippy::too_many_arguments)]
513    fn compute_t1(
514        ckt_builder: &mut CircuitBuilder<F>,
515        carry_layer: Option<&InputLayerNodeRef<F>>,
516        e: &NodeRef<F>,
517        f: &NodeRef<F>,
518        g: &NodeRef<F>,
519        h: &NodeRef<F>,
520        w_t: &NodeRef<F>,
521        k_t: &NodeRef<F>,
522    ) -> [Adder; 4] {
523        debug_assert!(e.get_num_vars() == 5);
524        debug_assert!(f.get_num_vars() == 5);
525        debug_assert!(g.get_num_vars() == 5);
526        debug_assert!(h.get_num_vars() == 5);
527        debug_assert!(w_t.get_num_vars() == 5);
528        debug_assert!(k_t.get_num_vars() == 5);
529
530        let t1_sigma_1 = Sigma1::new(ckt_builder, e);
531        let t1_ch = ChGate::new(ckt_builder, e, f, g);
532
533        // h1 + Sigma1(e)
534        let sum1 = Adder::layout_adder_circuit(
535            ckt_builder,
536            h,
537            &t1_sigma_1.get_output(),
538            carry_layer.cloned(),
539        );
540
541        // ch(e,f,g) + K_t
542        let sum2 = Adder::layout_adder_circuit(
543            ckt_builder,
544            &t1_ch.get_output(),
545            k_t,
546            carry_layer.cloned(),
547        );
548
549        // h1 + Sigma1(e) + ch(e,f,g) + K_t
550        let sum3 = Adder::layout_adder_circuit(
551            ckt_builder,
552            &sum1.get_output(),
553            &sum2.get_output(),
554            carry_layer.cloned(),
555        );
556
557        // h1 + Sigma1(e) + ch(e,f,g) + K_t + W_t
558        let sum4 =
559            Adder::layout_adder_circuit(ckt_builder, &sum3.get_output(), w_t, carry_layer.cloned());
560
561        [sum1, sum2, sum3, sum4]
562    }
563
564    fn compute_t2(
565        ckt_builder: &mut CircuitBuilder<F>,
566        carry_layer: Option<&InputLayerNodeRef<F>>,
567        a: &NodeRef<F>,
568        b: &NodeRef<F>,
569        c: &NodeRef<F>,
570    ) -> Adder {
571        let s1 = Sigma0::new(ckt_builder, a);
572        let m1 = MajGate::new(ckt_builder, a, b, c);
573        Adder::layout_adder_circuit(
574            ckt_builder,
575            &s1.get_output(),
576            &m1.get_output(),
577            carry_layer.cloned(),
578        )
579    }
580
581    // #[cfg(debug_assertions)]
582    // fn print_state(msg: String, state: &[u32]) {
583    //     println!("{}", msg);
584    //     println!(
585    //         "{}",
586    //         state
587    //             .iter()
588    //             .map(|v| format!("  0x{:08x}", v))
589    //             .collect::<Vec<_>>()
590    //             .join("\n")
591    //     );
592    // }
593
594    /// Populated the carry bits of the adder. This function must match
595    /// the addition operations exactly as during the circuit building.
596    pub fn populate_compression_fn(
597        &self,
598        circuit: &mut Circuit<F>,
599        message_words: Vec<u32>,
600        input_words: Vec<u32>,
601    ) -> Vec<u32> {
602        let mut a = input_words[0];
603        let mut b = input_words[1];
604        let mut c = input_words[2];
605        let mut d = input_words[3];
606        let mut e = input_words[4];
607        let mut f = input_words[5];
608        let mut g = input_words[6];
609        let mut h = input_words[7];
610
611        // #[cfg(debug_assertions)]
612        // Self::print_state(
613        //     "========= Message Schedule =========".to_string(),
614        //     &message_words,
615        // );
616
617        // #[cfg(debug_assertions)]
618        // Self::print_state(
619        //     "========= Initial State =========".to_string(),
620        //     &input_words,
621        // );
622
623        for (t, w_t) in message_words.iter().enumerate().take(64) {
624            let k_t = KeySchedule::<F>::ROUND_KEYS[t];
625            let [sum1, sum2, sum3, sum4] = &self.round_carries[t].t1_carries;
626            let t1 = Self::populate_t1(circuit, sum1, sum2, sum3, sum4, e, f, g, h, *w_t, k_t);
627            let t2 = Self::populate_t2(circuit, &self.round_carries[t].t2_carries, a, b, c);
628            h = g;
629            g = f;
630            f = e;
631            e = self.round_carries[t]
632                .e_carries
633                .perform_addition(circuit, d, t1);
634            d = c;
635            c = b;
636            b = a;
637            a = self.round_carries[t]
638                .a_carries
639                .perform_addition(circuit, t1, t2);
640
641            // #[cfg(debug_assertions)]
642            // Self::print_state(
643            //     format!("--------> {t} <---------"),
644            //     &[a, b, c, d, e, f, g, h],
645            // );
646        }
647
648        let intermediates = [a, b, c, d, e, f, g, h];
649
650        input_words
651            .iter()
652            .zip(intermediates.iter())
653            .zip(&self.output)
654            .map(|((h, hprime), gate)| gate.perform_addition(circuit, *h, *hprime))
655            .collect()
656    }
657
658    #[allow(clippy::too_many_arguments)]
659    fn populate_t1(
660        circuit: &mut Circuit<F>,
661        sum1: &Adder,
662        sum2: &Adder,
663        sum3: &Adder,
664        sum4: &Adder,
665        e: u32,
666        f: u32,
667        g: u32,
668        h: u32,
669        w_t: u32,
670        k_t: u32,
671    ) -> u32 {
672        let t1_sigma_1 = Sigma1::<F>::evaluate(e);
673        let t1_ch = ChGate::<F>::evaluate(e, f, g);
674
675        // h1 + Sigma1(e)
676        let sum1 = sum1.perform_addition(circuit, h, t1_sigma_1);
677
678        // ch(e,f,g) + K_t
679        let sum2 = sum2.perform_addition(circuit, t1_ch, k_t);
680
681        // h1 + Sigma1(e) + ch(e,f,g) + K_t
682        let sum3 = sum3.perform_addition(circuit, sum1, sum2);
683
684        // h1 + Sigma1(e) + ch(e,f,g) + K_t + W_t
685        sum4.perform_addition(circuit, sum3, w_t)
686    }
687
688    fn populate_t2(circuit: &mut Circuit<F>, sum1: &Adder, a: u32, b: u32, c: u32) -> u32 {
689        let s1 = Sigma0::<F>::evaluate(a);
690        let m1 = MajGate::<F>::evaluate(a, b, c);
691        sum1.perform_addition(circuit, s1, m1)
692    }
693
694    /// Returns 8 32-bit words (256-bits) output of the compression
695    /// function
696    pub fn get_output_nodes(&self) -> Vec<NodeRef<F>> {
697        self.output.iter().map(|n| n.get_output()).collect()
698    }
699}
700
/// Pads `input_data` as specified by SHA-256 (FIPS 180-4, section 5.1.1)
/// and returns the padded message as big-endian 32-bit words.
///
/// The padding is a single `0x80` byte, then as many zero bytes as needed,
/// then the message length in bits as a 64-bit big-endian integer, so that
/// the total length is a multiple of 64 bytes (16 words).
///
/// Messages whose length leaves fewer than 9 free bytes in the last block
/// (length mod 64 in `56..=63`) spill into one extra block.
fn sha256_padded_input(mut input_data: Vec<u8>) -> Vec<u32> {
    const BLOCK_BYTES: usize = 64;
    const LENGTH_FIELD_BYTES: usize = 8;

    // The length field is always 64 bits wide, independent of the width of
    // `usize` on the target; `usize` never exceeds 64 bits on supported targets.
    let input_len_bits = (input_data.len() as u64) * 8;

    input_data.push(0x80);
    // Round up to the first block boundary that still leaves room for the
    // length field; this cannot underflow for any message length.
    let padded_len = (input_data.len() + LENGTH_FIELD_BYTES).next_multiple_of(BLOCK_BYTES);
    input_data.resize(padded_len - LENGTH_FIELD_BYTES, 0);
    input_data.extend_from_slice(&input_len_bits.to_be_bytes());
    assert_eq!(input_data.len() % BLOCK_BYTES, 0);

    input_data
        .chunks_exact(4)
        .map(|chunk| u32::from_be_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect()
}
716
/// Everything belonging to one 512-bit message block: its message
/// schedule, its compression function, and the block's input words.
struct Sha256State<F: Field, Adder> {
    message_schedule: MessageSchedule<F, Adder>, // Expanded message schedule
    compression_fn: CompressionFn<F, Adder>, // Compression function laid out for this block
    input_chunks: Vec<u32>, // Input data chunked into 32-bit words
}
722
/// SHA-256 circuit over a (padded) multi-block message: the shared
/// constants plus one state per message block.
pub struct Sha256<F: Field, Adder> {
    // Round constants shared by every block.
    key_schedule: KeySchedule<F>,
    // Initial hash value fed into the first block.
    init_iv: HConstants<F>,
    // Per-block states, in message order.
    round_states: Vec<Sha256State<F, Adder>>,
}
729
730impl<F, Adder> Sha256<F, Adder>
731where
732    F: Field,
733    Adder: AdderGateTrait<F, IntegralType = u32> + Clone,
734{
735    /// Creates a new SHA256 circuit given arbitrary length data input_data
736    pub fn new(
737        ckt_builder: &mut CircuitBuilder<F>,
738        data_input_layer: &InputLayerNodeRef<F>,
739        carry_layer: Option<&InputLayerNodeRef<F>>,
740        input_data: Vec<u8>,
741    ) -> Self {
742        let key_schedule = KeySchedule::<F>::new(ckt_builder); // 32*64-bit
743        let init_iv = HConstants::<F>::new(ckt_builder); // 256-bit en
744        let input_data = sha256_padded_input(input_data);
745        let num_vars = input_data.len().ilog2() as usize + 5; // = log(32-bits * input_data.len())
746        let all_input = ckt_builder.add_input_shred("SHA256_input", num_vars, data_input_layer);
747
748        let b = &all_input;
749        let b_sq = AbstractExpression::products(vec![b.id(), b.id()]);
750        let b = b.expr();
751
752        // Check that all input bits are binary.
753        let binary_sector = ckt_builder.add_sector(
754            // b * (1 - b) = b - b^2
755            b - b_sq,
756        );
757
758        // Make sure all inputs are either `0` or `1`
759        ckt_builder.set_output(&binary_sector);
760
761        let mut input_schedule = init_iv.get_output_nodes();
762
763        let input_word_splits =
764            ckt_builder.add_split_node(&all_input, input_data.len().ilog2() as usize);
765
766        debug_assert_eq!(input_word_splits.len(), input_data.len());
767
768        let round_states = input_word_splits
769            .as_slice()
770            .chunks_exact(16)
771            .zip(input_data.as_slice().chunks_exact(16))
772            .map(|(msg_vars, data)| {
773                let message_schedule = MessageSchedule::new(ckt_builder, carry_layer, msg_vars);
774                let ckt = CompressionFn::new(
775                    ckt_builder,
776                    carry_layer,
777                    &message_schedule,
778                    &input_schedule,
779                    &key_schedule,
780                );
781                input_schedule = ckt.get_output_nodes();
782
783                Sha256State {
784                    message_schedule,
785                    input_chunks: data.to_vec(),
786                    compression_fn: ckt,
787                }
788            })
789            .collect_vec();
790        Self {
791            key_schedule,
792            round_states,
793            init_iv,
794        }
795    }
796
797    /// Returns the input message padded according to SHA256 spec
798    pub fn padded_data_chunks(&self) -> Vec<u32> {
799        self.round_states
800            .iter()
801            .flat_map(|st| st.input_chunks.clone())
802            .collect()
803    }
804
805    /// Returns the output node for the last Round of SHA256
806    pub fn get_output_node(&self) -> Vec<NodeRef<F>> {
807        self.round_states
808            .last()
809            .map(|v| v.compression_fn.get_output_nodes())
810            .unwrap()
811    }
812
813    /// Populates the state of each SHA Round
814    fn populate_state(
815        &self,
816        circuit: &mut Circuit<F>,
817        state: &Sha256State<F, Adder>,
818        input_words: Vec<u32>,
819    ) -> Vec<u32> {
820        let message_words = state
821            .message_schedule
822            .populate_message_schedule(circuit, &state.input_chunks);
823        state
824            .compression_fn
825            .populate_compression_fn(circuit, message_words, input_words)
826    }
827
828    /// Populates the circuit with the required data that was passed
829    /// during `new`. Returns the output of SHA256 hash
830    pub fn populate_circuit(&self, circuit: &mut Circuit<F>) -> Vec<u32> {
831        let all_bits = self
832            .round_states
833            .iter()
834            .flat_map(|st| st.input_chunks.iter().map(|v| bit_decompose_msb_first(*v)))
835            .flatten()
836            .map(u64::from)
837            .map(F::from)
838            .collect_vec();
839
840        self.key_schedule.populate_key_schedule(circuit);
841        self.init_iv.populate_iv(circuit);
842
843        let final_result = self
844            .round_states
845            .iter()
846            .fold(HConstants::<F>::H.to_vec(), |hash_val, st| {
847                self.populate_state(circuit, st, hash_val)
848            });
849
850        let input_mle = MultilinearExtension::new(all_bits);
851        circuit.set_input("SHA256_input", input_mle);
852        final_result
853    }
854}
855
#[cfg(test)]
mod tests {
    use super::add_get_carry_bits_lsb;

    /// Checks the sum and the LSB-first carry vector for additions that
    /// carry at every position or only at the top position, with and
    /// without a carry-in.
    #[test]
    fn test_bit_decomp_carry() {
        const ALL_ONES: u32 = 0xffffffff;
        const ONE: u32 = 0x1;
        const TOP_BIT: u32 = 0x80000000;

        let every_position_carries = vec![1u32; 32];
        let mut only_top_carries = vec![0u32; 32];
        only_top_carries[31] = 1;

        let cases = [
            (ALL_ONES, ONE, 0u32, &every_position_carries),
            (ALL_ONES, ONE, 1u32, &every_position_carries),
            (ALL_ONES, TOP_BIT, 0u32, &only_top_carries),
            (ALL_ONES, TOP_BIT, 1u32, &every_position_carries),
        ];

        for (lhs, rhs, carry_in, expected_carries) in cases {
            let (sum, carries) = add_get_carry_bits_lsb(lhs, rhs, carry_in);
            assert_eq!(sum, lhs.wrapping_add(rhs).wrapping_add(carry_in));
            assert_eq!(&carries, expected_carries);
        }
    }
}