//! remainder/prover/layers.rs — the list of layers making up a GKR circuit.
1use std::marker::PhantomData;
2
3use shared_types::Field;
4
5use crate::{
6 layer::{
7 gate::{BinaryOperation, GateLayer},
8 Layer, LayerId,
9 },
10 mle::dense::DenseMle,
11};
12
13use crate::mle::Mle;
14
#[derive(Clone, Debug)]
/// The list of Layers that make up the GKR circuit
pub struct Layers<F: Field, T: Layer<F>> {
    /// A Vec of pointers to various layer types
    pub layers: Vec<T>,
    /// Ties the otherwise-unused field type parameter `F` to the struct
    /// without storing a value of that type.
    marker: PhantomData<F>,
}
22
23impl<F: Field, T: Layer<F>> Layers<F, T> {
24 /// Add a batched Add Gate layer to a list of layers
25 /// In the batched case, consider a vector of mles corresponding to an mle for each "batch" or "copy".
26 /// Add a Gate layer to a list of layers
27 /// In the batched case (`num_dataparallel_bits` > 0), consider a vector of mles corresponding to an mle for each "batch" or "copy".
28 /// Then we refer to the mle that represents the concatenation of these mles by interleaving as the
29 /// flattened mle and each individual mle as a batched mle.
30 ///
31 /// # Arguments
32 /// * `nonzero_gates`: the gate wiring between single-copy circuit (as the wiring for each circuit remains the same)
33 ///
34 /// x is the label on the batched mle `lhs`, y is the label on the batched mle `rhs`, and z is the label on the next layer, batched
35 /// * `lhs`: the flattened mle representing the left side of the summation
36 /// * `rhs`: the flattened mle representing the right side of the summation
37 /// * `num_dataparallel_bits`: the number of bits representing the circuit copy we are looking at
38 /// * `gate_operation`: which operation the gate is performing. right now, can either be an 'add' or 'mul' gate
39 ///
40 /// # Returns
41 /// A flattened `DenseMle` that represents the evaluations of the add gate wiring on `lhs` and `rhs` over the boolean hypercube
42 pub fn add_gate(
43 &mut self,
44 nonzero_gates: Vec<(u32, u32, u32)>,
45 lhs: DenseMle<F>,
46 rhs: DenseMle<F>,
47 num_dataparallel_bits: Option<usize>,
48 gate_operation: BinaryOperation,
49 ) -> DenseMle<F>
50 where
51 T: From<GateLayer<F>>,
52 {
53 let id = LayerId::Layer(self.layers.len());
54 // constructor for batched mul gate struct
55 let gate: GateLayer<F> = GateLayer::new(
56 num_dataparallel_bits,
57 nonzero_gates.clone(),
58 lhs.clone(),
59 rhs.clone(),
60 gate_operation,
61 id,
62 );
63 let max_gate_val = nonzero_gates
64 .clone()
65 .into_iter()
66 .fold(0, |acc, (z, _, _)| std::cmp::max(acc, z));
67
68 // number of entries in the resulting table is the max gate z value * 2 to the power of the number of dataparallel bits, as we are
69 // evaluating over all values in the boolean hypercube which includes dataparallel bits
70 let num_dataparallel_vals = 1 << (num_dataparallel_bits.unwrap_or(0));
71 let res_table_num_entries = (max_gate_val + 1) * num_dataparallel_vals;
72 self.layers.push(gate.into());
73
74 // iterate through each of the indices and perform the binary operation specified
75 let mut res_table = vec![F::ZERO; res_table_num_entries as usize];
76 (0..num_dataparallel_vals).for_each(|idx| {
77 nonzero_gates
78 .clone()
79 .into_iter()
80 .for_each(|(z_ind, x_ind, y_ind)| {
81 let f2_val = lhs
82 .get((idx + (x_ind * num_dataparallel_vals)) as usize)
83 .unwrap_or(F::ZERO);
84 let f3_val = rhs
85 .get((idx + (y_ind * num_dataparallel_vals)) as usize)
86 .unwrap_or(F::ZERO);
87 res_table[(idx + (z_ind * num_dataparallel_vals)) as usize] =
88 gate_operation.perform_operation(f2_val, f3_val);
89 });
90 });
91
92 let res_mle: DenseMle<F> = DenseMle::new_from_raw(res_table, id);
93
94 res_mle
95 }
96
97 /// Creates a new Layers
98 pub fn new() -> Self {
99 Self {
100 layers: Vec::new(),
101 marker: PhantomData,
102 }
103 }
104
105 /// Creates a new [Layers] struct with populated layers values.
106 pub fn new_with_layers(layers: Vec<T>) -> Self {
107 Self {
108 layers,
109 marker: PhantomData,
110 }
111 }
112
113 /// Returns the number of layers in the GKR circuit
114 pub fn num_layers(&self) -> usize {
115 self.layers.len()
116 }
117}
118
119impl<F: Field, T: Layer<F>> Default for Layers<F, T> {
120 fn default() -> Self {
121 Self::new()
122 }
123}