frontend/components/sha2_gkr/
ripple_carry_adder.rs

1use super::AdderGateTrait;
2use crate::{
3    abstract_expr::AbstractExpression,
4    layouter::builder::{CircuitBuilder, InputLayerNodeRef, NodeRef},
5};
6use itertools::Itertools;
7use shared_types::Field;
8use std::marker::PhantomData;
9
10/// A single bit full adder
11#[derive(Clone, Debug)]
12pub struct FullAdder<F: Field> {
13    s: NodeRef<F>,
14    c: NodeRef<F>,
15}
16
17impl<F: Field> FullAdder<F> {
18    /// Build a full adder circuit where `x` is the first input wire,
19    /// `y` is the second input wire and `carry` is the input carry.
20    pub fn new(
21        builder_ref: &mut CircuitBuilder<F>,
22        x: &NodeRef<F>,
23        y: &NodeRef<F>,
24        carry: &NodeRef<F>,
25    ) -> Self {
26        debug_assert!(x.get_num_vars() == 0);
27        debug_assert!(y.get_num_vars() == 0);
28        debug_assert!(carry.get_num_vars() == 0);
29
30        // ab = a `and` b
31        let ab = x.clone().expr() * y.clone().expr();
32
33        // xor_ab1 = a `xor` b. `xor` output is normalized, i.e., its
34        // output is guaranteed to be in {0,1} if the input is in {0,1}.
35        // In the multiplicative basis where 0 -> 1, and 1 -> -1 (i.e.,
36        // the homomorphism x --> (-1)^x ), xor becomes a single
37        // multiplication, but such optimizations are outside the scope
38        // of this code.
39        let xor_ab1 = x.clone().expr() + y.clone().expr()
40            - AbstractExpression::constant(F::from(2)) * x.clone().expr() * y.clone().expr();
41
42        // sum = a `xor` b `xor` carry = ab1 `xor` carry
43        let bit_sum = builder_ref.add_sector(
44            xor_ab1.clone() + carry.clone().expr()
45                - AbstractExpression::constant(F::from(2)) * xor_ab1.clone() * carry.clone().expr(),
46        );
47
48        // cin_x_xor1  = carry & xor_ab1. Output is naturally normalized
49        let cin_x_xor1 = carry.expr() * xor_ab1.clone();
50
51        // Carry out = cin_x_xor1 | ab. Output is normalized.
52        let carry_out = builder_ref.add_sector(ab.clone() + cin_x_xor1.clone() - ab * cin_x_xor1);
53
54        Self {
55            s: bit_sum,
56            c: carry_out,
57        }
58    }
59
60    /// Return the output and carry of full adder
61    pub fn get_output(&self) -> (/* sum */ NodeRef<F>, /* carry */ NodeRef<F>) {
62        (self.s.clone(), self.c.clone())
63    }
64}
65
66/// mod 2^BITWIDTH adder with no input carry and no output carry
67#[derive(Clone)]
68pub struct RippleCarryAdderMod2w<const BITWIDTH: usize, F: Field> {
69    sum_node: NodeRef<F>,
70    _phantom: PhantomData<F>,
71}
72
73impl<F: Field> AdderGateTrait<F> for RippleCarryAdderMod2w<32, F> {
74    type IntegralType = u32;
75
76    fn layout_adder_circuit(
77        circuit_builder: &mut CircuitBuilder<F>, // Circuit builder
78        x_node: &NodeRef<F>,                     // reference to x in x + y
79        y_node: &NodeRef<F>,                     // reference to y in x + y
80        _: Option<InputLayerNodeRef<F>>,         // Carry Layer information
81    ) -> Self {
82        Self::new(circuit_builder, x_node, y_node)
83    }
84
85    /// Returns the output of AdderNoCarry
86    fn get_output(&self) -> NodeRef<F> {
87        self.sum_node.clone()
88    }
89
90    fn perform_addition(
91        &self,
92        _circuit: &mut crate::layouter::builder::Circuit<F>,
93        x: u32,
94        y: u32,
95    ) -> u32 {
96        x.wrapping_add(y)
97    }
98}
99
100impl<const BITWIDTH: usize, F> RippleCarryAdderMod2w<BITWIDTH, F>
101where
102    F: Field,
103{
104    /// Creates a BITWIDTH word Integer adder. For SHA-256/224 BITWIDTH is 32, for SHA-512/384, BITWIDTH is 64
105    ///
106    /// `x_word` and `y_word` are assumed to be MSB-first decomposition of
107    /// of the data.
108    pub fn new(
109        builder_ref: &mut CircuitBuilder<F>,
110        x_word: &NodeRef<F>,
111        y_word: &NodeRef<F>,
112    ) -> Self {
113        debug_assert!(BITWIDTH.is_power_of_two());
114        debug_assert!(x_word.get_num_vars() == BITWIDTH.ilog2() as usize);
115        debug_assert!(y_word.get_num_vars() == BITWIDTH.ilog2() as usize);
116        let num_vars = BITWIDTH.ilog2() as usize;
117        let x_word_wires = builder_ref.add_split_node(x_word, num_vars);
118        let y_word_wires = builder_ref.add_split_node(y_word, num_vars);
119        let mut c_in = builder_ref.add_sector(AbstractExpression::constant(F::from(0)));
120
121        let mut sum_expr = Vec::<NodeRef<F>>::with_capacity(BITWIDTH);
122
123        for (x, y) in x_word_wires.iter().zip(y_word_wires.iter()).rev() {
124            // Wires are in MSB first, hence the .rev()
125            let fa = FullAdder::new(builder_ref, x, y, &c_in);
126            let (s, c) = fa.get_output();
127            c_in = c;
128            sum_expr.push(s);
129        }
130
131        // Make sure that c_out of the last round is either 0 or 1
132        let c_out_expr = c_in.expr();
133        let one_or_zero =
134            c_out_expr.clone() * (AbstractExpression::constant(F::from(1)) - c_out_expr);
135        let carry_sector = builder_ref.add_sector(one_or_zero);
136        builder_ref.set_output(&carry_sector);
137
138        // Swap to MSB first hence the rev()
139        let sum_rewired = sum_expr.iter().rev().map(|n| n.expr()).collect_vec();
140
141        let sum_node =
142            builder_ref.add_sector(AbstractExpression::<F>::binary_tree_selector(sum_rewired));
143
144        Self {
145            sum_node,
146            _phantom: Default::default(),
147        }
148    }
149}