frontend/components/sha2_gkr/
brent_kung_adder.rs1use super::AdderGateTrait;
6use crate::{
7 abstract_expr::AbstractExpression,
8 layouter::builder::{CircuitBuilder, InputLayerNodeRef, NodeRef},
9};
10use shared_types::Field;
11use std::marker::PhantomData;
12
13#[inline(always)]
14fn mul<F: Field>(a: AbstractExpression<F>, b: AbstractExpression<F>) -> AbstractExpression<F> {
15 a * b
16}
17
18#[inline(always)]
19fn xor<F: Field>(a: AbstractExpression<F>, b: AbstractExpression<F>) -> AbstractExpression<F> {
20 a.clone() + b.clone() - AbstractExpression::constant(F::from(2)) * a * b
21}
22
23fn pp_adder_4_bit<F: Field>(
31 builder_ref: &mut CircuitBuilder<F>,
32 x_val: NodeRef<F>,
33 y_val: NodeRef<F>,
34 carry_in: Option<NodeRef<F>>,
35) -> (NodeRef<F> , NodeRef<F> ) {
36 debug_assert!(x_val.get_num_vars() == 2);
37 debug_assert!(y_val.get_num_vars() == 2);
38 debug_assert!(carry_in
39 .as_ref()
40 .map(|v| v.get_num_vars() == 0)
41 .unwrap_or(true));
42
43 let propagate = builder_ref.add_sector(
44 x_val.expr() + y_val.expr()
45 - AbstractExpression::constant(F::from(2)) * x_val.expr() * y_val.expr(),
46 );
47 let generate = builder_ref.add_sector(x_val.expr() * y_val.expr());
48 let mut p = builder_ref.add_split_node(&propagate, 2);
49 let mut g = builder_ref.add_split_node(&generate, 2);
50
51 p.reverse();
53 g.reverse();
54
55 assert!(p.len() == 4);
56 assert!(g.len() == 4);
57
58 let p1g0 = mul(p[1].expr(), g[0].expr());
60 let p0p1 = mul(p[0].expr(), p[1].expr());
61 let p2p3 = mul(p[2].expr(), p[3].expr());
62
63 let g10 = xor(g[1].expr(), p1g0.clone());
64 let g20 = mul(p[2].expr(), g10.clone());
65 let g20 = xor(g[2].expr(), g20.clone());
66 let g30 = mul(p[3].expr(), g20.clone());
67 let g30 = xor(g[3].expr(), g30.clone());
68
69 let c0 = carry_in
71 .map(|v| v.expr())
72 .unwrap_or(AbstractExpression::constant(F::ZERO));
73 let tmp = mul(p[0].expr(), c0.clone());
74 let c1 = xor(g[0].expr(), tmp);
75 let tmp = mul(p0p1.clone(), c0.clone());
76 let c2 = xor(g10.clone(), tmp);
77 let tmp = mul(p[2].expr(), c0.clone());
78 let tmp = mul(p0p1.clone(), tmp);
79 let c3 = xor(g20, tmp);
80 let tmp = mul(p0p1, p2p3);
81 let tmp = mul(tmp, c0.clone());
82 let c4 = xor(g30, tmp);
83
84 let sum = vec![
86 xor(p[3].expr(), c3),
87 xor(p[2].expr(), c2),
88 xor(p[1].expr(), c1),
89 xor(p[0].expr(), c0),
90 ];
91
92 (
93 builder_ref.add_sector(AbstractExpression::binary_tree_selector(sum)),
94 builder_ref.add_sector(c4),
95 )
96}
97
98#[derive(Debug, Clone)]
100pub struct BKAdder<const BITWIDTH: usize, F: Field> {
101 sum_node: NodeRef<F>,
102 _phantom: PhantomData<F>,
103}
104
105impl<F: Field> AdderGateTrait<F> for BKAdder<32, F> {
106 type IntegralType = u32;
107
108 fn layout_adder_circuit(
109 circuit_builder: &mut CircuitBuilder<F>, x_node: &NodeRef<F>, y_node: &NodeRef<F>, carry_layer: Option<InputLayerNodeRef<F>>, ) -> Self {
114 Self::new(circuit_builder, x_node, y_node, carry_layer)
115 }
116
117 fn get_output(&self) -> NodeRef<F> {
119 self.sum_node.clone()
120 }
121
122 fn perform_addition(
123 &self,
124 _circuit: &mut crate::layouter::builder::Circuit<F>,
125 x: u32,
126 y: u32,
127 ) -> u32 {
128 x.wrapping_add(y)
129 }
130}
131
132impl<const BITWIDTH: usize, F> BKAdder<BITWIDTH, F>
133where
134 F: Field,
135{
136 fn new(
137 builder_ref: &mut CircuitBuilder<F>,
138 x_word: &NodeRef<F>,
139 y_word: &NodeRef<F>,
140 _carry_layer: Option<InputLayerNodeRef<F>>,
141 ) -> Self {
142 assert!(
143 BITWIDTH.is_multiple_of(8),
144 "Only bitwidths of multiple of 8 are supported"
145 );
146
147 assert!(
148 x_word.get_num_vars() == BITWIDTH.ilog2() as usize,
149 "The number of variables must match"
150 );
151
152 let chunks = BITWIDTH.ilog2().saturating_sub(2) as usize;
153
154 assert!(chunks > 0);
155 let x_4bit_chunks = builder_ref.add_split_node(x_word, chunks);
156 let y_4bit_chunks = builder_ref.add_split_node(y_word, chunks);
157
158 let mut sum_chunks = Vec::<NodeRef<F>>::new();
159 let mut carry: Option<NodeRef<F>> = None;
160
161 x_4bit_chunks
162 .into_iter()
163 .rev()
164 .zip(y_4bit_chunks.into_iter().rev())
165 .for_each(|(x, y)| {
166 let (s, c) = pp_adder_4_bit(builder_ref, x, y, carry.clone());
167 sum_chunks.push(s);
168 carry = Some(c);
169 });
170
171 let final_sum = sum_chunks.into_iter().rev().map(|v| v.expr()).collect();
172
173 let final_carry = carry.map(|c| c.expr()).unwrap();
174
175 let final_carry_is_one_or_zero = builder_ref
176 .add_sector(final_carry.clone() * (AbstractExpression::constant(F::ONE) - final_carry));
177
178 builder_ref.set_output(&final_carry_is_one_or_zero);
179
180 Self {
181 sum_node: builder_ref.add_sector(AbstractExpression::binary_tree_selector(final_sum)),
182 _phantom: Default::default(),
183 }
184 }
185}