frontend/zk_iriscode_ss/
components.rs

1use shared_types::Field;
2
3use crate::abstract_expr::AbstractExpression;
4use crate::layouter::builder::{CircuitBuilder, NodeRef};
5
/// Components for Zk iris code computation.
///
/// This is a stateless unit type used purely as a namespace: it carries no data and its
/// functionality is exposed through associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct ZkIriscodeComponent;
8
9impl ZkIriscodeComponent {
10    /// Calculates a sum of products of two equal-length vectors of Nodes.  For example, can be used for
11    /// computing a random linear combination of some nodes - in this case, the `lh_multiplicands` would
12    /// be instances of [crate::layouter::nodes::fiat_shamir_challenge::FiatShamirChallengeNode].
13    pub fn sum_of_products<F: Field>(
14        builder_ref: &mut CircuitBuilder<F>,
15        lh_multiplicands: Vec<&NodeRef<F>>,
16        rh_multiplicands: Vec<&NodeRef<F>>,
17    ) -> NodeRef<F> {
18        assert_eq!(lh_multiplicands.len(), rh_multiplicands.len());
19        let sector = builder_ref.add_sector(
20            lh_multiplicands
21                .iter()
22                .zip(rh_multiplicands)
23                .fold(AbstractExpression::constant(F::ZERO), |acc, (lh, rh)| {
24                    acc + AbstractExpression::products(vec![lh.id(), rh.id()])
25                }),
26        );
27        sector
28    }
29}
30
31#[cfg(test)]
32mod test {
33    use shared_types::{Field, Fr};
34
35    use crate::layouter::builder::{Circuit, CircuitBuilder, LayerVisibility};
36    use remainder::{
37        mle::evals::MultilinearExtension,
38        prover::helpers::test_circuit_with_runtime_optimized_config,
39    };
40
41    use super::ZkIriscodeComponent;
42
43    use anyhow::Result;
44
45    fn build_sum_of_products_circuit<F: Field>() -> Result<Circuit<F>> {
46        let mut builder = CircuitBuilder::<F>::new();
47
48        let n_summands = 4;
49        let rh_vector_num_vars = 1;
50        let lh_vector_num_vars = 0;
51        // Vectors to be summed together
52        let rh_input_layer = builder.add_input_layer("RH", LayerVisibility::Public);
53        let rh_input_shreds = (0..n_summands)
54            .map(|i| {
55                builder.add_input_shred(
56                    &format!("RH Input Shred {i}"),
57                    rh_vector_num_vars,
58                    &rh_input_layer,
59                )
60            })
61            .collect::<Vec<_>>();
62        // Coefficients to multiple the vectors by
63        let lh_input_layer = builder.add_input_layer("LH", LayerVisibility::Public);
64        let lh_input_shreds = (0..n_summands)
65            .map(|i| {
66                builder.add_input_shred(
67                    &format!("LH Input Shred {i}"),
68                    lh_vector_num_vars,
69                    &lh_input_layer,
70                )
71            })
72            .collect::<Vec<_>>();
73        let sop = ZkIriscodeComponent::sum_of_products(
74            &mut builder,
75            lh_input_shreds.iter().collect(),
76            rh_input_shreds.iter().collect(),
77        );
78
79        let _output = builder.set_output(&sop);
80
81        builder.build_with_layer_combination()
82    }
83
84    #[test]
85    fn test_sum_of_products() {
86        let mut circuit = build_sum_of_products_circuit::<Fr>().unwrap();
87        [
88            Fr::from(17).neg(),
89            Fr::from(20).neg(),
90            Fr::from(2),
91            Fr::from(1),
92        ]
93        .into_iter()
94        .enumerate()
95        .for_each(|(i, elem)| {
96            circuit.set_input(
97                &format!("LH Input Shred {i}"),
98                MultilinearExtension::new(vec![elem]),
99            );
100        });
101
102        [
103            vec![Fr::from(1), Fr::from(0)],
104            vec![Fr::from(0), Fr::from(1)],
105            vec![Fr::from(5), Fr::from(6)],
106            vec![Fr::from(7), Fr::from(8)],
107        ]
108        .into_iter()
109        .enumerate()
110        .for_each(|(i, mle)| {
111            circuit.set_input(
112                &format!("RH Input Shred {i}"),
113                MultilinearExtension::new(mle),
114            );
115        });
116
117        let provable_circuit = circuit.gen_provable_circuit().unwrap();
118
119        test_circuit_with_runtime_optimized_config(&provable_circuit);
120    }
121}