frontend/components/sha2_gkr/
nonlinear_gates.rs1#![allow(non_snake_case)]
6use crate::{
7 abstract_expr::AbstractExpression,
8 components::binary_operations::{logical_shift::ShiftNode, rotate_bits::RotateNode},
9 layouter::builder::{Circuit, CircuitBuilder, NodeRef},
10};
11use itertools::Itertools;
12use remainder::mle::evals::MultilinearExtension;
13use shared_types::Field;
14use std::ops::{BitOr, BitXor, Shl, Shr};
15
/// Fixed-width integer words that can be split into individual bits and rotated.
///
/// The supertrait bounds are exactly what the default [`rotr`](Self::rotr) and
/// [`rotl`](Self::rotl) implementations need: shifts by a `usize` amount and bitwise OR.
pub trait IsBitDecomposable:
    Shl<usize, Output = Self> + Shr<usize, Output = Self> + BitOr<Output = Self> + Sized + Copy
{
    /// Returns bit `index` of `self` (index 0 is the least-significant bit) as the value
    /// `0` or `1` of type `Self`.
    fn get_bit(&self, index: usize) -> Self;

    /// Rotates `self` right by `index` bits.
    ///
    /// `index` is reduced modulo the bit width, so any `usize` is accepted and multiples of
    /// the width return `self` unchanged.
    ///
    /// NOTE: built from `>>`, which is an arithmetic (sign-extending) shift on signed
    /// integers, so this is only a true rotation for unsigned types or non-negative values.
    fn rotr(&self, index: usize) -> Self {
        let bit_count = 8 * std::mem::size_of::<Self>();
        let rotation = index % bit_count;
        if rotation == 0 {
            // A zero rotation would otherwise need a shift by `bit_count`, which overflows.
            return *self;
        }
        (*self >> rotation) | (*self << (bit_count - rotation))
    }

    /// Rotates `self` left by `index` bits.
    ///
    /// `index` is reduced modulo the bit width, so any `usize` is accepted and multiples of
    /// the width return `self` unchanged.
    ///
    /// NOTE: built from `>>`, which is an arithmetic (sign-extending) shift on signed
    /// integers, so this is only a true rotation for unsigned types or non-negative values.
    fn rotl(&self, index: usize) -> Self {
        let bit_count = 8 * std::mem::size_of::<Self>();
        let rotation = index % bit_count;
        if rotation == 0 {
            // A zero rotation would otherwise need a shift by `bit_count`, which overflows.
            return *self;
        }
        (*self << rotation) | (*self >> (bit_count - rotation))
    }
}
40
41impl IsBitDecomposable for i8 {
42 fn get_bit(&self, index: usize) -> Self {
43 assert!((0..8).contains(&index));
44 (*self >> index) & 0x1
45 }
46}
47
48impl IsBitDecomposable for u8 {
49 fn get_bit(&self, index: usize) -> Self {
50 assert!((0..8).contains(&index));
51 (*self >> index) & 0x1
52 }
53}
54
55impl IsBitDecomposable for i16 {
56 fn get_bit(&self, index: usize) -> Self {
57 assert!((0..16).contains(&index));
58 (*self >> index) & 0x1
59 }
60}
61
62impl IsBitDecomposable for u16 {
63 fn get_bit(&self, index: usize) -> Self {
64 assert!((0..16).contains(&index));
65 (*self >> index) & 0x1
66 }
67}
68
69impl IsBitDecomposable for i32 {
70 fn get_bit(&self, index: usize) -> Self {
71 assert!((0..32).contains(&index));
72 (*self >> index) & 0x1
73 }
74}
75
76impl IsBitDecomposable for u32 {
77 fn get_bit(&self, index: usize) -> Self {
78 assert!((0..32).contains(&index));
79 (*self >> index) & 0x1
80 }
81}
82
83impl IsBitDecomposable for i64 {
84 fn get_bit(&self, index: usize) -> Self {
85 assert!((0..64).contains(&index));
86 (*self >> index) & 0x1
87 }
88}
89
90impl IsBitDecomposable for u64 {
91 fn get_bit(&self, index: usize) -> Self {
92 assert!((0..64).contains(&index));
93 (*self >> index) & 0x1
94 }
95}
96
/// Maps a SHA-2 word size in bits to the number of variables needed to index its bits
/// (`log2` of the word size): 32 -> 5, 64 -> 6.
///
/// # Panics
/// Panics with "Invalid SHA wordsize" for any other value.
const fn sha_words_2_num_vars(value: usize) -> usize {
    match value {
        32 => 5,
        64 => 6,
        _ => panic!("Invalid SHA wordsize"),
    }
}
106
107#[inline]
109pub fn bit_decompose_msb_first<T>(input: T) -> Vec<T>
110where
111 T: IsBitDecomposable,
112{
113 let bit_count = 8 * std::mem::size_of::<T>();
114 let mut result = Vec::<T>::with_capacity(bit_count);
115
116 for i in 0..bit_count {
117 let v = input.get_bit(bit_count - 1 - i);
118 result.push(v);
119 }
120
121 result
122}
123
124#[inline]
126pub fn bit_decompose_lsb_first<T>(input: T) -> Vec<T>
127where
128 T: IsBitDecomposable,
129{
130 let bit_count = 8 * std::mem::size_of::<T>();
131 let mut result = Vec::<T>::with_capacity(bit_count);
132
133 for i in 0..bit_count {
134 let v = input.get_bit(i);
135 result.push(v);
136 }
137
138 result
139}
140
141#[derive(Clone, Debug)]
146pub struct ConstInputGate<F: Field> {
147 data_node: NodeRef<F>,
148 bits_mle: MultilinearExtension<F>,
149 constant_name: String,
150}
151
152impl<F: Field> ConstInputGate<F> {
153 pub fn new<T>(
157 builder_ref: &mut CircuitBuilder<F>,
158 constant_name: &str,
159 constant_value: T,
160 ) -> Self
161 where
162 T: IsBitDecomposable,
163 u64: From<T>,
164 {
165 let input_layer = builder_ref.add_input_layer(
166 constant_name,
167 crate::layouter::builder::LayerVisibility::Public,
168 );
169
170 let bits = bit_decompose_msb_first(constant_value);
171 let num_vars = bits.len().ilog2() as usize;
172
173 let bits_mle =
174 MultilinearExtension::new(bits.into_iter().map(u64::from).map(F::from).collect_vec());
175
176 let data_node = builder_ref.add_input_shred(constant_name, num_vars, &input_layer);
177
178 let b = &data_node;
181 let b_sq = AbstractExpression::products(vec![b.id(), b.id()]);
182 let b = b.expr();
183 let binary_sector = builder_ref.add_sector(b - b_sq);
184 builder_ref.set_output(&binary_sector);
185
186 Self {
187 data_node,
188 bits_mle,
189 constant_name: constant_name.into(),
190 }
191 }
192
193 pub fn add_to_circuit(&self, circuit: &mut Circuit<F>) {
196 circuit.set_input(&self.constant_name, self.bits_mle.clone());
197 }
198
199 pub fn input_mle(&self) -> &MultilinearExtension<F> {
201 &self.bits_mle
202 }
203
204 pub fn get_output(&self) -> NodeRef<F> {
208 self.data_node.clone()
209 }
210
211 pub fn get_output_ref(&self) -> &NodeRef<F> {
213 &self.data_node
214 }
215}
216
217#[derive(Clone, Debug)]
219pub struct ChGate<F: Field> {
220 ch_sector: NodeRef<F>,
221}
222
223impl<F: Field> ChGate<F> {
224 pub fn new(
228 builder_ref: &mut CircuitBuilder<F>,
229 x_vars: &NodeRef<F>,
230 y_vars: &NodeRef<F>,
231 z_vars: &NodeRef<F>,
232 ) -> Self {
233 debug_assert!(x_vars.get_num_vars() == 5 || x_vars.get_num_vars() == 6);
234 debug_assert!(y_vars.get_num_vars() == 5 || x_vars.get_num_vars() == 6);
235 debug_assert!(z_vars.get_num_vars() == 5 || x_vars.get_num_vars() == 6);
236
237 assert!(x_vars.get_num_vars() == y_vars.get_num_vars());
238 assert!(x_vars.get_num_vars() == z_vars.get_num_vars());
239
240 let x_AND_y = AbstractExpression::products(vec![x_vars.clone().id(), y_vars.clone().id()]);
242
243 let NOT_x = builder_ref.add_sector(AbstractExpression::constant(F::ONE) - x_vars.expr());
245
246 let ch_sector = builder_ref.add_sector(x_AND_y + (z_vars.expr() * NOT_x.expr()));
252
253 Self { ch_sector }
254 }
255
256 pub fn get_output(&self) -> NodeRef<F> {
259 self.ch_sector.clone()
260 }
261
262 pub const fn evaluate(x: u32, y: u32, z: u32) -> u32 {
264 (x & y) ^ (!x & z)
265 }
266}
267
268#[derive(Clone, Debug)]
271pub struct MajGate<F: Field> {
272 maj_sector: NodeRef<F>,
273}
274
275impl<F: Field> MajGate<F> {
276 pub fn new(
278 builder_ref: &mut CircuitBuilder<F>,
279 x_vars: &NodeRef<F>,
280 y_vars: &NodeRef<F>,
281 z_vars: &NodeRef<F>,
282 ) -> Self {
283 debug_assert!(x_vars.get_num_vars() == 5 || x_vars.get_num_vars() == 6);
284 debug_assert!(y_vars.get_num_vars() == 5 || x_vars.get_num_vars() == 6);
285 debug_assert!(z_vars.get_num_vars() == 5 || x_vars.get_num_vars() == 6);
286
287 assert!(x_vars.get_num_vars() == y_vars.get_num_vars());
288 assert!(x_vars.get_num_vars() == z_vars.get_num_vars());
289
290 let const_2 = AbstractExpression::constant(F::from(2));
297 let xy = x_vars.expr() * y_vars.expr();
298 let yz = y_vars.expr() * z_vars.expr();
299 let xz = x_vars.expr() * z_vars.expr();
300 let xyz = xy.clone() * z_vars.expr();
301 let x_p_y_p_z =
302 x_vars.expr() + y_vars.expr() + z_vars.expr() - const_2.clone() * xyz.clone();
303
304 let maj_sector = builder_ref.add_sector(xy + yz + xz - const_2 * xyz * x_p_y_p_z);
305
306 Self { maj_sector }
307 }
308
309 pub fn get_output(&self) -> NodeRef<F> {
312 self.maj_sector.clone()
313 }
314
315 pub const fn evaluate(x: u32, y: u32, z: u32) -> u32 {
317 (x & y) ^ (y & z) ^ (x & z)
318 }
319}
320
321#[derive(Clone, Debug)]
326pub struct Sigma<
327 F: Field,
328 const WORD_SIZE: usize,
329 const ROTR1: i32,
330 const ROTR2: i32,
331 const ROTR3: i32,
332> {
333 sigma_sector: NodeRef<F>,
334}
335
336impl<F: Field, const WORD_SIZE: usize, const ROTR1: i32, const ROTR2: i32, const ROTR3: i32>
337 Sigma<F, WORD_SIZE, ROTR1, ROTR2, ROTR3>
338{
339 pub fn new(builder_ref: &mut CircuitBuilder<F>, x_vars: &NodeRef<F>) -> Self {
341 let num_vars: usize = sha_words_2_num_vars(WORD_SIZE);
342 let rotr1 = RotateNode::new(builder_ref, num_vars, ROTR1, x_vars);
343 let rotr2 = RotateNode::new(builder_ref, num_vars, ROTR2, x_vars);
344 let rotr3 = RotateNode::new(builder_ref, num_vars, ROTR3, x_vars);
345
346 let r1_expr = rotr1.get_output().expr();
347 let r2_expr = rotr2.get_output().expr();
348 let r3_expr = rotr3.get_output().expr();
349 let r1_xor_r2 = r1_expr.clone() + r2_expr.clone()
350 - AbstractExpression::constant(F::from(2)) * r1_expr * r2_expr;
351
352 let r1_xor_r2_xor_r3 = r1_xor_r2.clone() + r3_expr.clone()
353 - AbstractExpression::constant(F::from(2)) * r1_xor_r2 * r3_expr;
354
355 let sigma_sector = builder_ref.add_sector(r1_xor_r2_xor_r3);
356
357 Self { sigma_sector }
358 }
359
360 pub fn get_output(&self) -> NodeRef<F> {
362 self.sigma_sector.clone()
363 }
364
365 pub fn evaluate<T>(x_data: T) -> T
367 where
368 T: IsBitDecomposable + BitXor<Output = T>,
369 {
370 let rotr1 = x_data.rotr(ROTR1 as usize);
371 let rotr2 = x_data.rotr(ROTR2 as usize);
372 let rotr3 = x_data.rotr(ROTR3 as usize);
373 rotr1 ^ rotr2 ^ rotr3
374 }
375}
376
377#[derive(Clone, Debug)]
384pub struct SmallSigma<
385 F: Field,
386 const WORD_SIZE: usize,
387 const ROTR1: i32,
388 const ROTR2: i32,
389 const SHR3: i32,
390> {
391 sigma_sector: NodeRef<F>,
392}
393
394impl<F: Field, const WORD_SIZE: usize, const ROTR1: i32, const ROTR2: i32, const SHR3: i32>
395 SmallSigma<F, WORD_SIZE, ROTR1, ROTR2, SHR3>
396{
397 pub fn new(builder_ref: &mut CircuitBuilder<F>, x_vars: &NodeRef<F>) -> Self {
399 let num_vars: usize = sha_words_2_num_vars(WORD_SIZE);
400 let rotr1 = RotateNode::new(builder_ref, num_vars, ROTR1, x_vars);
401 let rotr2 = RotateNode::new(builder_ref, num_vars, ROTR2, x_vars);
402 let shr3 = ShiftNode::new(builder_ref, num_vars, SHR3, x_vars);
403
404 let r1_expr = rotr1.get_output().expr();
405 let r2_expr = rotr2.get_output().expr();
406 let s3_expr = shr3.get_output().expr();
407
408 let r1_xor_r2 = r1_expr.clone() + r2_expr.clone()
409 - AbstractExpression::constant(F::from(2)) * r1_expr * r2_expr;
410
411 let r1_xor_r2_xor_s3 = r1_xor_r2.clone() + s3_expr.clone()
412 - AbstractExpression::constant(F::from(2)) * r1_xor_r2 * s3_expr;
413
414 let sigma_sector = builder_ref.add_sector(r1_xor_r2_xor_s3);
415
416 Self { sigma_sector }
417 }
418
419 pub fn get_output(&self) -> NodeRef<F> {
421 self.sigma_sector.clone()
422 }
423
424 pub fn evaluate<T>(x_data: T) -> T
426 where
427 T: IsBitDecomposable + BitXor<Output = T>,
428 {
429 let rotr1 = x_data.rotr(ROTR1 as usize);
430 let rotr2 = x_data.rotr(ROTR2 as usize);
431 let shr3 = x_data.shr(SHR3 as usize);
432 rotr1 ^ rotr2 ^ shr3
433 }
434}