frontend/components/sha2_gkr/
brent_kung_adder.rs

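//! The Brent-Kung adder is a parallel prefix adder, described in <https://maths-people.anu.edu.au/~brent/pd/rpb060_IEEETC.pdf>.
//! This module adds words in 4-bit parallel-prefix blocks, passing the carry out of each block
//! into the next more significant block.
//!
//! For a comparison of different adders, see also <https://www.lirmm.fr/arith18/papers/patil-RobustEnergyEffcientAdder.pdf>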
4
5use super::AdderGateTrait;
6use crate::{
7    abstract_expr::AbstractExpression,
8    layouter::builder::{CircuitBuilder, InputLayerNodeRef, NodeRef},
9};
10use shared_types::Field;
11use std::marker::PhantomData;
12
13#[inline(always)]
14fn mul<F: Field>(a: AbstractExpression<F>, b: AbstractExpression<F>) -> AbstractExpression<F> {
15    a * b
16}
17
18#[inline(always)]
19fn xor<F: Field>(a: AbstractExpression<F>, b: AbstractExpression<F>) -> AbstractExpression<F> {
20    a.clone() + b.clone() - AbstractExpression::constant(F::from(2)) * a * b
21}
22
23// #[inline(always)]
24// fn or<F: Field>(a: AbstractExpression<F>, b: AbstractExpression<F>) -> AbstractExpression<F> {
25//     a.clone() + b.clone() - a * b
26// }
27
28// Input values in MSB format and compute the 4-bit parallel prefix
29// adder
30fn pp_adder_4_bit<F: Field>(
31    builder_ref: &mut CircuitBuilder<F>,
32    x_val: NodeRef<F>,
33    y_val: NodeRef<F>,
34    carry_in: Option<NodeRef<F>>,
35) -> (NodeRef<F> /* sum */, NodeRef<F> /*carry  */) {
36    debug_assert!(x_val.get_num_vars() == 2);
37    debug_assert!(y_val.get_num_vars() == 2);
38    debug_assert!(carry_in
39        .as_ref()
40        .map(|v| v.get_num_vars() == 0)
41        .unwrap_or(true));
42
43    let propagate = builder_ref.add_sector(
44        x_val.expr() + y_val.expr()
45            - AbstractExpression::constant(F::from(2)) * x_val.expr() * y_val.expr(),
46    );
47    let generate = builder_ref.add_sector(x_val.expr() * y_val.expr());
48    let mut p = builder_ref.add_split_node(&propagate, 2);
49    let mut g = builder_ref.add_split_node(&generate, 2);
50
51    // Convert to LSB first mode
52    p.reverse();
53    g.reverse();
54
55    assert!(p.len() == 4);
56    assert!(g.len() == 4);
57
58    // Step 2: Prefix computation
59    let p1g0 = mul(p[1].expr(), g[0].expr());
60    let p0p1 = mul(p[0].expr(), p[1].expr());
61    let p2p3 = mul(p[2].expr(), p[3].expr());
62
63    let g10 = xor(g[1].expr(), p1g0.clone());
64    let g20 = mul(p[2].expr(), g10.clone());
65    let g20 = xor(g[2].expr(), g20.clone());
66    let g30 = mul(p[3].expr(), g20.clone());
67    let g30 = xor(g[3].expr(), g30.clone());
68
69    // Step 3: Calculate carries
70    let c0 = carry_in
71        .map(|v| v.expr())
72        .unwrap_or(AbstractExpression::constant(F::ZERO));
73    let tmp = mul(p[0].expr(), c0.clone());
74    let c1 = xor(g[0].expr(), tmp);
75    let tmp = mul(p0p1.clone(), c0.clone());
76    let c2 = xor(g10.clone(), tmp);
77    let tmp = mul(p[2].expr(), c0.clone());
78    let tmp = mul(p0p1.clone(), tmp);
79    let c3 = xor(g20, tmp);
80    let tmp = mul(p0p1, p2p3);
81    let tmp = mul(tmp, c0.clone());
82    let c4 = xor(g30, tmp);
83
84    // Reversed bit order
85    let sum = vec![
86        xor(p[3].expr(), c3),
87        xor(p[2].expr(), c2),
88        xor(p[1].expr(), c1),
89        xor(p[0].expr(), c0),
90    ];
91
92    (
93        builder_ref.add_sector(AbstractExpression::binary_tree_selector(sum)),
94        builder_ref.add_sector(c4),
95    )
96}
97
98/// Brent-Kung Adder
99#[derive(Debug, Clone)]
100pub struct BKAdder<const BITWIDTH: usize, F: Field> {
101    sum_node: NodeRef<F>,
102    _phantom: PhantomData<F>,
103}
104
105impl<F: Field> AdderGateTrait<F> for BKAdder<32, F> {
106    type IntegralType = u32;
107
108    fn layout_adder_circuit(
109        circuit_builder: &mut CircuitBuilder<F>,   // Circuit builder
110        x_node: &NodeRef<F>,                       // reference to x in x + y
111        y_node: &NodeRef<F>,                       // reference to y in x + y
112        carry_layer: Option<InputLayerNodeRef<F>>, // Carry Layer information
113    ) -> Self {
114        Self::new(circuit_builder, x_node, y_node, carry_layer)
115    }
116
117    /// Returns the output of AdderNoCarry
118    fn get_output(&self) -> NodeRef<F> {
119        self.sum_node.clone()
120    }
121
122    fn perform_addition(
123        &self,
124        _circuit: &mut crate::layouter::builder::Circuit<F>,
125        x: u32,
126        y: u32,
127    ) -> u32 {
128        x.wrapping_add(y)
129    }
130}
131
132impl<const BITWIDTH: usize, F> BKAdder<BITWIDTH, F>
133where
134    F: Field,
135{
136    fn new(
137        builder_ref: &mut CircuitBuilder<F>,
138        x_word: &NodeRef<F>,
139        y_word: &NodeRef<F>,
140        _carry_layer: Option<InputLayerNodeRef<F>>,
141    ) -> Self {
142        assert!(
143            BITWIDTH.is_multiple_of(8),
144            "Only bitwidths of multiple of 8 are supported"
145        );
146
147        assert!(
148            x_word.get_num_vars() == BITWIDTH.ilog2() as usize,
149            "The number of variables must match"
150        );
151
152        let chunks = BITWIDTH.ilog2().saturating_sub(2) as usize;
153
154        assert!(chunks > 0);
155        let x_4bit_chunks = builder_ref.add_split_node(x_word, chunks);
156        let y_4bit_chunks = builder_ref.add_split_node(y_word, chunks);
157
158        let mut sum_chunks = Vec::<NodeRef<F>>::new();
159        let mut carry: Option<NodeRef<F>> = None;
160
161        x_4bit_chunks
162            .into_iter()
163            .rev()
164            .zip(y_4bit_chunks.into_iter().rev())
165            .for_each(|(x, y)| {
166                let (s, c) = pp_adder_4_bit(builder_ref, x, y, carry.clone());
167                sum_chunks.push(s);
168                carry = Some(c);
169            });
170
171        let final_sum = sum_chunks.into_iter().rev().map(|v| v.expr()).collect();
172
173        let final_carry = carry.map(|c| c.expr()).unwrap();
174
175        let final_carry_is_one_or_zero = builder_ref
176            .add_sector(final_carry.clone() * (AbstractExpression::constant(F::ONE) - final_carry));
177
178        builder_ref.set_output(&final_carry_is_one_or_zero);
179
180        Self {
181            sum_node: builder_ref.add_sector(AbstractExpression::binary_tree_selector(final_sum)),
182            _phantom: Default::default(),
183        }
184    }
185}