frontend/components/binary_operations/
binary_adder.rs

//! Implements binary addition gates.
2
3use std::cmp::max;
4
5use shared_types::Field;
6
7use crate::{
8    abstract_expr::AbstractExpression,
9    components::binary_operations::logical_shift::ShiftNode,
10    layouter::builder::{CircuitBuilder, NodeRef},
11};
12// use remainder::expression::abstract_expr::ExprBuilder;
13
14/// Performs binary addition between two nodes that represent binary values, given the vector of
15/// carries as a witness. Works with bit-widths that are powers of 2 up to `2^30 = 1,073,741,824`
16/// bits (this constraint is inherited from `ShiftNode`).
17///
18/// # Requires
19/// All inputs are assumed to only contain binary digits (i.e. only values from the set
20/// `{F::ZERO, F::ONE}` for a field `F`).
21#[derive(Clone, Debug)]
22pub struct BinaryAdder<F: Field> {
23    adder_sector: NodeRef<F>,
24}
25
26impl<F: Field> BinaryAdder<F> {
27    /// Generates a new [BinaryAdder] adding the values in nodes `lhs_bits` and `rhs_bits`, given
28    /// the `carry_bits` as a witness.
29    pub fn new(
30        builder_ref: &mut CircuitBuilder<F>,
31        lhs_bits: &NodeRef<F>,
32        rhs_bits: &NodeRef<F>,
33        carry_bits: &NodeRef<F>,
34    ) -> Self {
35        let num_vars = max(
36            carry_bits.get_num_vars(),
37            max(lhs_bits.get_num_vars(), rhs_bits.get_num_vars()),
38        );
39
40        // Shift the carry bits by one to the left to align then so that they can
41        // be added along with the corresponding LHS and RHS bit using a full-adder circuit.
42        let shifted_carries = ShiftNode::new(builder_ref, num_vars, -1, carry_bits);
43
44        let b0 = lhs_bits;
45        let b1 = rhs_bits;
46        let c = shifted_carries.get_output();
47
48        let b0_c = AbstractExpression::products(vec![b0.id(), c.id()]);
49        let b1_c = AbstractExpression::products(vec![b1.id(), c.id()]);
50        let b0_b1 = AbstractExpression::products(vec![b0.id(), b1.id()]);
51        let b0_b1_c = AbstractExpression::products(vec![b0.id(), b1.id(), c.id()]);
52
53        let two_b0_c = AbstractExpression::scaled(b0_c, F::from(2));
54        let two_b1_c = AbstractExpression::scaled(b1_c, F::from(2));
55        let two_b0_b1 = AbstractExpression::scaled(b0_b1, F::from(2));
56        let four_b0_b1_c = AbstractExpression::scaled(b0_b1_c, F::from(4));
57
58        let full_adder_result_sector = builder_ref.add_sector(
59            // The following expression is equivalent to: `b0 XOR b1 XOR c`
60            b0.expr() + b1.expr() + c.expr() - two_b0_c - two_b1_c - two_b0_b1 + four_b0_b1_c,
61        );
62
63        let b0 = lhs_bits;
64        let b1 = rhs_bits;
65        let c = shifted_carries.get_output();
66        let expected_c = carry_bits;
67
68        let carry_check_sector = builder_ref.add_sector(
69            // The next carry is 1 iff at least 2 of the 3 input bits (`b0`, `b1` and `c`) are
70            // 1. The following expression is the multilinear polynomial extending the boolean
71            // function described in the previous sentence.
72            b0.expr() * b1.expr() * c.expr()
73                + b0.expr() * b1.expr() * -(c.expr() - F::ONE)
74                + b0.expr() * -(b1.expr() - F::ONE) * c.expr()
75                + -(b0.expr() - F::ONE) * b1.expr() * c.expr()
76                - expected_c.expr(),
77        );
78
79        builder_ref.set_output(&carry_check_sector);
80
81        Self {
82            adder_sector: full_adder_result_sector,
83        }
84    }
85
86    /// Returns a reference to the output of the adder circuit.
87    pub fn get_output(&self) -> NodeRef<F> {
88        self.adder_sector.clone()
89    }
90}