frontend/components/sha2_gkr/
nonlinear_gates.rs

//! Implementation of the different non-linear gates used in the SHA-2
//! family of circuits, as described in NIST FIPS 180-4
//! <https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf>
4
5#![allow(non_snake_case)]
6use crate::{
7    abstract_expr::AbstractExpression,
8    components::binary_operations::{logical_shift::ShiftNode, rotate_bits::RotateNode},
9    layouter::builder::{Circuit, CircuitBuilder, NodeRef},
10};
11use itertools::Itertools;
12use remainder::mle::evals::MultilinearExtension;
13use shared_types::Field;
14use std::ops::{BitOr, BitXor, Shl, Shr};
15
/// A trait to deal with bit decomposition and bit rotation as needed by
/// SHA-2 family of hash functions.
///
/// Rotations operate on the full width of `Self`, i.e.
/// `8 * size_of::<Self>()` bits. They are intended for unsigned integer
/// types. For signed types `>>` is an arithmetic shift that replicates the
/// sign bit, so `rotr`/`rotl` do not yield a true bit rotation when the
/// top bit is set.
pub trait IsBitDecomposable:
    Shl<usize, Output = Self> + Shr<usize, Output = Self> + BitOr<Output = Self> + Sized + Copy
{
    /// Gets the `index`-th bit of the data, where index 0 is the least
    /// significant bit.
    fn get_bit(&self, index: usize) -> Self;

    /// Rotates bits right by `index` positions.
    ///
    /// `index` is reduced modulo the bit width, so 0 and values greater
    /// than or equal to the bit width are accepted.
    fn rotr(&self, index: usize) -> Self {
        let bit_count = 8 * std::mem::size_of::<Self>();
        let rotation = index % bit_count;
        if rotation == 0 {
            // A zero rotation is the identity. Computing it with shifts
            // would shift by the full bit width, which overflows.
            return *self;
        }
        (*self >> rotation) | (*self << (bit_count - rotation))
    }

    /// Rotates bits left by `index` positions.
    ///
    /// `index` is reduced modulo the bit width, so 0 and values greater
    /// than or equal to the bit width are accepted.
    fn rotl(&self, index: usize) -> Self {
        let bit_count = 8 * std::mem::size_of::<Self>();
        let rotation = index % bit_count;
        if rotation == 0 {
            // See `rotr`: avoid a full-width shift for the identity case.
            return *self;
        }
        (*self << rotation) | (*self >> (bit_count - rotation))
    }
}
40
41impl IsBitDecomposable for i8 {
42    fn get_bit(&self, index: usize) -> Self {
43        assert!((0..8).contains(&index));
44        (*self >> index) & 0x1
45    }
46}
47
48impl IsBitDecomposable for u8 {
49    fn get_bit(&self, index: usize) -> Self {
50        assert!((0..8).contains(&index));
51        (*self >> index) & 0x1
52    }
53}
54
55impl IsBitDecomposable for i16 {
56    fn get_bit(&self, index: usize) -> Self {
57        assert!((0..16).contains(&index));
58        (*self >> index) & 0x1
59    }
60}
61
62impl IsBitDecomposable for u16 {
63    fn get_bit(&self, index: usize) -> Self {
64        assert!((0..16).contains(&index));
65        (*self >> index) & 0x1
66    }
67}
68
69impl IsBitDecomposable for i32 {
70    fn get_bit(&self, index: usize) -> Self {
71        assert!((0..32).contains(&index));
72        (*self >> index) & 0x1
73    }
74}
75
76impl IsBitDecomposable for u32 {
77    fn get_bit(&self, index: usize) -> Self {
78        assert!((0..32).contains(&index));
79        (*self >> index) & 0x1
80    }
81}
82
83impl IsBitDecomposable for i64 {
84    fn get_bit(&self, index: usize) -> Self {
85        assert!((0..64).contains(&index));
86        (*self >> index) & 0x1
87    }
88}
89
90impl IsBitDecomposable for u64 {
91    fn get_bit(&self, index: usize) -> Self {
92        assert!((0..64).contains(&index));
93        (*self >> index) & 0x1
94    }
95}
96
/// Maps a SHA-2 word size in bits (32 or 64) to the number of MLE
/// variables needed to index its bits (log2 of the word size).
///
/// # Panics
/// Panics with "Invalid SHA wordsize" for any other word size.
const fn sha_words_2_num_vars(value: usize) -> usize {
    match value {
        32 => 5,
        64 => 6,
        _ => panic!("Invalid SHA wordsize"),
    }
}
106
107/// Decompose a numerical type in bits: MSB first
108#[inline]
109pub fn bit_decompose_msb_first<T>(input: T) -> Vec<T>
110where
111    T: IsBitDecomposable,
112{
113    let bit_count = 8 * std::mem::size_of::<T>();
114    let mut result = Vec::<T>::with_capacity(bit_count);
115
116    for i in 0..bit_count {
117        let v = input.get_bit(bit_count - 1 - i);
118        result.push(v);
119    }
120
121    result
122}
123
124/// Decompose a numerical type in bits: LSB first
125#[inline]
126pub fn bit_decompose_lsb_first<T>(input: T) -> Vec<T>
127where
128    T: IsBitDecomposable,
129{
130    let bit_count = 8 * std::mem::size_of::<T>();
131    let mut result = Vec::<T>::with_capacity(bit_count);
132
133    for i in 0..bit_count {
134        let v = input.get_bit(i);
135        result.push(v);
136    }
137
138    result
139}
140
141/// Represents a constant input to the circuit. Unlike other gates, the
142/// ConstInputGate takes as input  a name for the constant and its value
143/// and creates an input shred as well the MLE of input data that can be
144/// bound to the circuit at a later time.
145#[derive(Clone, Debug)]
146pub struct ConstInputGate<F: Field> {
147    data_node: NodeRef<F>,
148    bits_mle: MultilinearExtension<F>,
149    constant_name: String,
150}
151
152impl<F: Field> ConstInputGate<F> {
153    /// Creates a constant input gate, with name `constant_name` with
154    /// value `constant_value`. When binding to the Circuit, the same
155    /// `constant_name` should be used in call to Circuit::set_input()
156    pub fn new<T>(
157        builder_ref: &mut CircuitBuilder<F>,
158        constant_name: &str,
159        constant_value: T,
160    ) -> Self
161    where
162        T: IsBitDecomposable,
163        u64: From<T>,
164    {
165        let input_layer = builder_ref.add_input_layer(
166            constant_name,
167            crate::layouter::builder::LayerVisibility::Public,
168        );
169
170        let bits = bit_decompose_msb_first(constant_value);
171        let num_vars = bits.len().ilog2() as usize;
172
173        let bits_mle =
174            MultilinearExtension::new(bits.into_iter().map(u64::from).map(F::from).collect_vec());
175
176        let data_node = builder_ref.add_input_shred(constant_name, num_vars, &input_layer);
177
178        // Make sure inputs are all 1s or zero 0s by creating an assert0
179        // check over x*(1-x).
180        let b = &data_node;
181        let b_sq = AbstractExpression::products(vec![b.id(), b.id()]);
182        let b = b.expr();
183        let binary_sector = builder_ref.add_sector(b - b_sq);
184        builder_ref.set_output(&binary_sector);
185
186        Self {
187            data_node,
188            bits_mle,
189            constant_name: constant_name.into(),
190        }
191    }
192
193    /// Given an instantiated circuit, adds the constant gate to the
194    /// circuit as input with correct input label name.
195    pub fn add_to_circuit(&self, circuit: &mut Circuit<F>) {
196        circuit.set_input(&self.constant_name, self.bits_mle.clone());
197    }
198
199    /// Returns the MLE of bit-decomposition of the constant data
200    pub fn input_mle(&self) -> &MultilinearExtension<F> {
201        &self.bits_mle
202    }
203
204    /// Returns the node that represents this constant. If the same
205    /// constant is used in multiple places, this allows re-using the
206    /// same constant input gate.
207    pub fn get_output(&self) -> NodeRef<F> {
208        self.data_node.clone()
209    }
210
211    /// Returns the constant gate nodes
212    pub fn get_output_ref(&self) -> &NodeRef<F> {
213        &self.data_node
214    }
215}
216
217/// Multiplexer gate as defined by SHA-2 family of circuits.
218#[derive(Clone, Debug)]
219pub struct ChGate<F: Field> {
220    ch_sector: NodeRef<F>,
221}
222
223impl<F: Field> ChGate<F> {
224    /// Computes bit_wise selection of y_vars or x_vars as defined in
225    /// SHA-2 spec. Assumes inputs to the gate are normalized (i.e. in
226    /// {0,1}).
227    pub fn new(
228        builder_ref: &mut CircuitBuilder<F>,
229        x_vars: &NodeRef<F>,
230        y_vars: &NodeRef<F>,
231        z_vars: &NodeRef<F>,
232    ) -> Self {
233        debug_assert!(x_vars.get_num_vars() == 5 || x_vars.get_num_vars() == 6);
234        debug_assert!(y_vars.get_num_vars() == 5 || x_vars.get_num_vars() == 6);
235        debug_assert!(z_vars.get_num_vars() == 5 || x_vars.get_num_vars() == 6);
236
237        assert!(x_vars.get_num_vars() == y_vars.get_num_vars());
238        assert!(x_vars.get_num_vars() == z_vars.get_num_vars());
239
240        // Compute x `and` y
241        let x_AND_y = AbstractExpression::products(vec![x_vars.clone().id(), y_vars.clone().id()]);
242
243        // Compute NOT x = 1 - x
244        let NOT_x = builder_ref.add_sector(AbstractExpression::constant(F::ONE) - x_vars.expr());
245
246        // Compute (x `and` y) `xor` (NOT x `and` Z). Note that x only
247        // selects one bit from either x or z, so the output is
248        // guaranteed to be in {0,1}. Therefore safe to not do the full
249        // xor normalization which will require 32 or 64
250        // multiplications.
251        let ch_sector = builder_ref.add_sector(x_AND_y + (z_vars.expr() * NOT_x.expr()));
252
253        Self { ch_sector }
254    }
255
256    /// Returns the output values of the ChGate in MSB first
257    /// bit-decomposed form
258    pub fn get_output(&self) -> NodeRef<F> {
259        self.ch_sector.clone()
260    }
261
262    /// Computes ch(x,y,z) natively
263    pub const fn evaluate(x: u32, y: u32, z: u32) -> u32 {
264        (x & y) ^ (!x & z)
265    }
266}
267
268/// Bitwise majority selector gate as defined by SHA-2 family of Hash
269/// functions.
270#[derive(Clone, Debug)]
271pub struct MajGate<F: Field> {
272    maj_sector: NodeRef<F>,
273}
274
275impl<F: Field> MajGate<F> {
276    /// Compute bit-wise majority of `x_vars`, `y_vars`, and `z_vars`.
277    pub fn new(
278        builder_ref: &mut CircuitBuilder<F>,
279        x_vars: &NodeRef<F>,
280        y_vars: &NodeRef<F>,
281        z_vars: &NodeRef<F>,
282    ) -> Self {
283        debug_assert!(x_vars.get_num_vars() == 5 || x_vars.get_num_vars() == 6);
284        debug_assert!(y_vars.get_num_vars() == 5 || x_vars.get_num_vars() == 6);
285        debug_assert!(z_vars.get_num_vars() == 5 || x_vars.get_num_vars() == 6);
286
287        assert!(x_vars.get_num_vars() == y_vars.get_num_vars());
288        assert!(x_vars.get_num_vars() == z_vars.get_num_vars());
289
290        // We need the gates to produce normalize output (i.e., output
291        // in {0,1} basis) therefore the arithmetization is
292        //
293        // maj(x,y,z) = x*y + y*z + x*z - 2*x*y*z*(x + y + z - 2*x*y*z)
294        //
295
296        let const_2 = AbstractExpression::constant(F::from(2));
297        let xy = x_vars.expr() * y_vars.expr();
298        let yz = y_vars.expr() * z_vars.expr();
299        let xz = x_vars.expr() * z_vars.expr();
300        let xyz = xy.clone() * z_vars.expr();
301        let x_p_y_p_z =
302            x_vars.expr() + y_vars.expr() + z_vars.expr() - const_2.clone() * xyz.clone();
303
304        let maj_sector = builder_ref.add_sector(xy + yz + xz - const_2 * xyz * x_p_y_p_z);
305
306        Self { maj_sector }
307    }
308
309    /// Returns the output values of MajGate in MSB first bit-decomposed
310    /// form
311    pub fn get_output(&self) -> NodeRef<F> {
312        self.maj_sector.clone()
313    }
314
315    /// Computes maj(x,y,z) natively
316    pub const fn evaluate(x: u32, y: u32, z: u32) -> u32 {
317        (x & y) ^ (y & z) ^ (x & z)
318    }
319}
320
321/// The capital Sigma function described on Printed Page Number 10 in
322/// NIST SP-180.4. The const parameters ROTR1, ROTR2, ROTR3 denote the
323/// rotations defined in NIST spec. Bits are assumed to be in MSB first
324/// decomposition form.
325#[derive(Clone, Debug)]
326pub struct Sigma<
327    F: Field,
328    const WORD_SIZE: usize,
329    const ROTR1: i32,
330    const ROTR2: i32,
331    const ROTR3: i32,
332> {
333    sigma_sector: NodeRef<F>,
334}
335
336impl<F: Field, const WORD_SIZE: usize, const ROTR1: i32, const ROTR2: i32, const ROTR3: i32>
337    Sigma<F, WORD_SIZE, ROTR1, ROTR2, ROTR3>
338{
339    /// Compute capital Sigma of `x_vars`
340    pub fn new(builder_ref: &mut CircuitBuilder<F>, x_vars: &NodeRef<F>) -> Self {
341        let num_vars: usize = sha_words_2_num_vars(WORD_SIZE);
342        let rotr1 = RotateNode::new(builder_ref, num_vars, ROTR1, x_vars);
343        let rotr2 = RotateNode::new(builder_ref, num_vars, ROTR2, x_vars);
344        let rotr3 = RotateNode::new(builder_ref, num_vars, ROTR3, x_vars);
345
346        let r1_expr = rotr1.get_output().expr();
347        let r2_expr = rotr2.get_output().expr();
348        let r3_expr = rotr3.get_output().expr();
349        let r1_xor_r2 = r1_expr.clone() + r2_expr.clone()
350            - AbstractExpression::constant(F::from(2)) * r1_expr * r2_expr;
351
352        let r1_xor_r2_xor_r3 = r1_xor_r2.clone() + r3_expr.clone()
353            - AbstractExpression::constant(F::from(2)) * r1_xor_r2 * r3_expr;
354
355        let sigma_sector = builder_ref.add_sector(r1_xor_r2_xor_r3);
356
357        Self { sigma_sector }
358    }
359
360    /// Get output of capital Sigma gate in MSB-first bit decomposed form
361    pub fn get_output(&self) -> NodeRef<F> {
362        self.sigma_sector.clone()
363    }
364
365    /// Evaluate the capital Sigma gate for given x_data
366    pub fn evaluate<T>(x_data: T) -> T
367    where
368        T: IsBitDecomposable + BitXor<Output = T>,
369    {
370        let rotr1 = x_data.rotr(ROTR1 as usize);
371        let rotr2 = x_data.rotr(ROTR2 as usize);
372        let rotr3 = x_data.rotr(ROTR3 as usize);
373        rotr1 ^ rotr2 ^ rotr3
374    }
375}
376
377/// The Small Sigma function described on Printed Page Number 10
378/// in NIST SP-180.4. The const parameters have following meaning
379///  ROTR1 : Value of rotation in first ROTR
380///  ROTR2 : Value of rotation in second ROTR
381///  SHR3 : Value of rotation in third SHR
382/// MSB-first bit decomposition required.
383#[derive(Clone, Debug)]
384pub struct SmallSigma<
385    F: Field,
386    const WORD_SIZE: usize,
387    const ROTR1: i32,
388    const ROTR2: i32,
389    const SHR3: i32,
390> {
391    sigma_sector: NodeRef<F>,
392}
393
394impl<F: Field, const WORD_SIZE: usize, const ROTR1: i32, const ROTR2: i32, const SHR3: i32>
395    SmallSigma<F, WORD_SIZE, ROTR1, ROTR2, SHR3>
396{
397    /// Compute small Sigma of `x_vars`
398    pub fn new(builder_ref: &mut CircuitBuilder<F>, x_vars: &NodeRef<F>) -> Self {
399        let num_vars: usize = sha_words_2_num_vars(WORD_SIZE);
400        let rotr1 = RotateNode::new(builder_ref, num_vars, ROTR1, x_vars);
401        let rotr2 = RotateNode::new(builder_ref, num_vars, ROTR2, x_vars);
402        let shr3 = ShiftNode::new(builder_ref, num_vars, SHR3, x_vars);
403
404        let r1_expr = rotr1.get_output().expr();
405        let r2_expr = rotr2.get_output().expr();
406        let s3_expr = shr3.get_output().expr();
407
408        let r1_xor_r2 = r1_expr.clone() + r2_expr.clone()
409            - AbstractExpression::constant(F::from(2)) * r1_expr * r2_expr;
410
411        let r1_xor_r2_xor_s3 = r1_xor_r2.clone() + s3_expr.clone()
412            - AbstractExpression::constant(F::from(2)) * r1_xor_r2 * s3_expr;
413
414        let sigma_sector = builder_ref.add_sector(r1_xor_r2_xor_s3);
415
416        Self { sigma_sector }
417    }
418
419    /// Return output of Small Sigma
420    pub fn get_output(&self) -> NodeRef<F> {
421        self.sigma_sector.clone()
422    }
423
424    /// Evaluation with actual input data
425    pub fn evaluate<T>(x_data: T) -> T
426    where
427        T: IsBitDecomposable + BitXor<Output = T>,
428    {
429        let rotr1 = x_data.rotr(ROTR1 as usize);
430        let rotr2 = x_data.rotr(ROTR2 as usize);
431        let shr3 = x_data.shr(SHR3 as usize);
432        rotr1 ^ rotr2 ^ shr3
433    }
434}