// remainder/prover/layers.rs

1use std::marker::PhantomData;
2
3use shared_types::Field;
4
5use crate::{
6    layer::{
7        gate::{BinaryOperation, GateLayer},
8        Layer, LayerId,
9    },
10    mle::dense::DenseMle,
11};
12
13use crate::mle::Mle;
14
15#[derive(Clone, Debug)]
16/// The list of Layers that make up the GKR circuit
17pub struct Layers<F: Field, T: Layer<F>> {
18    /// A Vec of pointers to various layer types
19    pub layers: Vec<T>,
20    marker: PhantomData<F>,
21}
22
23impl<F: Field, T: Layer<F>> Layers<F, T> {
24    /// Add a batched Add Gate layer to a list of layers
25    /// In the batched case, consider a vector of mles corresponding to an mle for each "batch" or "copy".
26    /// Add a Gate layer to a list of layers
27    /// In the batched case (`num_dataparallel_bits` > 0), consider a vector of mles corresponding to an mle for each "batch" or "copy".
28    /// Then we refer to the mle that represents the concatenation of these mles by interleaving as the
29    /// flattened mle and each individual mle as a batched mle.
30    ///
31    /// # Arguments
32    /// * `nonzero_gates`: the gate wiring between single-copy circuit (as the wiring for each circuit remains the same)
33    ///
34    /// x is the label on the batched mle `lhs`, y is the label on the batched mle `rhs`, and z is the label on the next layer, batched
35    /// * `lhs`: the flattened mle representing the left side of the summation
36    /// * `rhs`: the flattened mle representing the right side of the summation
37    /// * `num_dataparallel_bits`: the number of bits representing the circuit copy we are looking at
38    /// * `gate_operation`: which operation the gate is performing. right now, can either be an 'add' or 'mul' gate
39    ///
40    /// # Returns
41    /// A flattened `DenseMle` that represents the evaluations of the add gate wiring on `lhs` and `rhs` over the boolean hypercube
42    pub fn add_gate(
43        &mut self,
44        nonzero_gates: Vec<(u32, u32, u32)>,
45        lhs: DenseMle<F>,
46        rhs: DenseMle<F>,
47        num_dataparallel_bits: Option<usize>,
48        gate_operation: BinaryOperation,
49    ) -> DenseMle<F>
50    where
51        T: From<GateLayer<F>>,
52    {
53        let id = LayerId::Layer(self.layers.len());
54        // constructor for batched mul gate struct
55        let gate: GateLayer<F> = GateLayer::new(
56            num_dataparallel_bits,
57            nonzero_gates.clone(),
58            lhs.clone(),
59            rhs.clone(),
60            gate_operation,
61            id,
62        );
63        let max_gate_val = nonzero_gates
64            .clone()
65            .into_iter()
66            .fold(0, |acc, (z, _, _)| std::cmp::max(acc, z));
67
68        // number of entries in the resulting table is the max gate z value * 2 to the power of the number of dataparallel bits, as we are
69        // evaluating over all values in the boolean hypercube which includes dataparallel bits
70        let num_dataparallel_vals = 1 << (num_dataparallel_bits.unwrap_or(0));
71        let res_table_num_entries = (max_gate_val + 1) * num_dataparallel_vals;
72        self.layers.push(gate.into());
73
74        // iterate through each of the indices and perform the binary operation specified
75        let mut res_table = vec![F::ZERO; res_table_num_entries as usize];
76        (0..num_dataparallel_vals).for_each(|idx| {
77            nonzero_gates
78                .clone()
79                .into_iter()
80                .for_each(|(z_ind, x_ind, y_ind)| {
81                    let f2_val = lhs
82                        .get((idx + (x_ind * num_dataparallel_vals)) as usize)
83                        .unwrap_or(F::ZERO);
84                    let f3_val = rhs
85                        .get((idx + (y_ind * num_dataparallel_vals)) as usize)
86                        .unwrap_or(F::ZERO);
87                    res_table[(idx + (z_ind * num_dataparallel_vals)) as usize] =
88                        gate_operation.perform_operation(f2_val, f3_val);
89                });
90        });
91
92        let res_mle: DenseMle<F> = DenseMle::new_from_raw(res_table, id);
93
94        res_mle
95    }
96
97    /// Creates a new Layers
98    pub fn new() -> Self {
99        Self {
100            layers: Vec::new(),
101            marker: PhantomData,
102        }
103    }
104
105    /// Creates a new [Layers] struct with populated layers values.
106    pub fn new_with_layers(layers: Vec<T>) -> Self {
107        Self {
108            layers,
109            marker: PhantomData,
110        }
111    }
112
113    /// Returns the number of layers in the GKR circuit
114    pub fn num_layers(&self) -> usize {
115        self.layers.len()
116    }
117}
118
119impl<F: Field, T: Layer<F>> Default for Layers<F, T> {
120    fn default() -> Self {
121        Self::new()
122    }
123}