frontend/components/sha2_gkr/
ripple_carry_adder.rs1use super::AdderGateTrait;
2use crate::{
3 abstract_expr::AbstractExpression,
4 layouter::builder::{CircuitBuilder, InputLayerNodeRef, NodeRef},
5};
6use itertools::Itertools;
7use shared_types::Field;
8use std::marker::PhantomData;
9
10#[derive(Clone, Debug)]
12pub struct FullAdder<F: Field> {
13 s: NodeRef<F>,
14 c: NodeRef<F>,
15}
16
17impl<F: Field> FullAdder<F> {
18 pub fn new(
21 builder_ref: &mut CircuitBuilder<F>,
22 x: &NodeRef<F>,
23 y: &NodeRef<F>,
24 carry: &NodeRef<F>,
25 ) -> Self {
26 debug_assert!(x.get_num_vars() == 0);
27 debug_assert!(y.get_num_vars() == 0);
28 debug_assert!(carry.get_num_vars() == 0);
29
30 let ab = x.clone().expr() * y.clone().expr();
32
33 let xor_ab1 = x.clone().expr() + y.clone().expr()
40 - AbstractExpression::constant(F::from(2)) * x.clone().expr() * y.clone().expr();
41
42 let bit_sum = builder_ref.add_sector(
44 xor_ab1.clone() + carry.clone().expr()
45 - AbstractExpression::constant(F::from(2)) * xor_ab1.clone() * carry.clone().expr(),
46 );
47
48 let cin_x_xor1 = carry.expr() * xor_ab1.clone();
50
51 let carry_out = builder_ref.add_sector(ab.clone() + cin_x_xor1.clone() - ab * cin_x_xor1);
53
54 Self {
55 s: bit_sum,
56 c: carry_out,
57 }
58 }
59
60 pub fn get_output(&self) -> (NodeRef<F>, NodeRef<F>) {
62 (self.s.clone(), self.c.clone())
63 }
64}
65
66#[derive(Clone)]
68pub struct RippleCarryAdderMod2w<const BITWIDTH: usize, F: Field> {
69 sum_node: NodeRef<F>,
70 _phantom: PhantomData<F>,
71}
72
73impl<F: Field> AdderGateTrait<F> for RippleCarryAdderMod2w<32, F> {
74 type IntegralType = u32;
75
76 fn layout_adder_circuit(
77 circuit_builder: &mut CircuitBuilder<F>, x_node: &NodeRef<F>, y_node: &NodeRef<F>, _: Option<InputLayerNodeRef<F>>, ) -> Self {
82 Self::new(circuit_builder, x_node, y_node)
83 }
84
85 fn get_output(&self) -> NodeRef<F> {
87 self.sum_node.clone()
88 }
89
90 fn perform_addition(
91 &self,
92 _circuit: &mut crate::layouter::builder::Circuit<F>,
93 x: u32,
94 y: u32,
95 ) -> u32 {
96 x.wrapping_add(y)
97 }
98}
99
100impl<const BITWIDTH: usize, F> RippleCarryAdderMod2w<BITWIDTH, F>
101where
102 F: Field,
103{
104 pub fn new(
109 builder_ref: &mut CircuitBuilder<F>,
110 x_word: &NodeRef<F>,
111 y_word: &NodeRef<F>,
112 ) -> Self {
113 debug_assert!(BITWIDTH.is_power_of_two());
114 debug_assert!(x_word.get_num_vars() == BITWIDTH.ilog2() as usize);
115 debug_assert!(y_word.get_num_vars() == BITWIDTH.ilog2() as usize);
116 let num_vars = BITWIDTH.ilog2() as usize;
117 let x_word_wires = builder_ref.add_split_node(x_word, num_vars);
118 let y_word_wires = builder_ref.add_split_node(y_word, num_vars);
119 let mut c_in = builder_ref.add_sector(AbstractExpression::constant(F::from(0)));
120
121 let mut sum_expr = Vec::<NodeRef<F>>::with_capacity(BITWIDTH);
122
123 for (x, y) in x_word_wires.iter().zip(y_word_wires.iter()).rev() {
124 let fa = FullAdder::new(builder_ref, x, y, &c_in);
126 let (s, c) = fa.get_output();
127 c_in = c;
128 sum_expr.push(s);
129 }
130
131 let c_out_expr = c_in.expr();
133 let one_or_zero =
134 c_out_expr.clone() * (AbstractExpression::constant(F::from(1)) - c_out_expr);
135 let carry_sector = builder_ref.add_sector(one_or_zero);
136 builder_ref.set_output(&carry_sector);
137
138 let sum_rewired = sum_expr.iter().rev().map(|n| n.expr()).collect_vec();
140
141 let sum_node =
142 builder_ref.add_sector(AbstractExpression::<F>::binary_tree_selector(sum_rewired));
143
144 Self {
145 sum_node,
146 _phantom: Default::default(),
147 }
148 }
149}