frontend/zk_iriscode_ss/
circuits.rs

1#![allow(clippy::type_complexity)]
2
3use crate::components::digits::DigitComponents;
4use crate::layouter::builder::{Circuit, CircuitBuilder, LayerVisibility};
5use crate::zk_iriscode_ss::components::ZkIriscodeComponent;
6use crate::zk_iriscode_ss::data::IriscodeCircuitAuxData;
7use remainder::mle::evals::MultilinearExtension;
8use remainder::utils::arithmetic::log2_ceil;
9
10use itertools::Itertools;
11use shared_types::Field;
12
13use super::data::IriscodeCircuitInputData;
14
15use anyhow::Result;
16
// ---- Image input layer --------------------------------------------------------------------

/// Label of the input layer holding the image (typically private).
pub const V3_INPUT_IMAGE_LAYER: &str = "Input image (to reroute)";
/// Label of the single shred on the image layer: the image that gets rerouted into strips.
pub const V3_INPUT_IMAGE_SHRED: &str = "Image to reroute";

// ---- Digits input layer -------------------------------------------------------------------

/// Label of the input layer holding the digit values and the digit multiplicities
/// (typically private).
pub const V3_DIGITS_LAYER: &str = "Digit values and multiplicities";
/// Label prefix of the per-digit shreds; the digit index is appended after a single space
/// (i.e. `"{V3_DIGITS_SHRED_TEMPLATE} {i}"`).
pub const V3_DIGITS_SHRED_TEMPLATE: &str = "Digits Input Shred";
/// Label of the shred holding the digit multiplicities used by the lookup constraint.
pub const V3_DIGITS_MULTIPLICITIES_SHRED: &str = "Digits multiplicities";

// ---- Sign bits input layer ----------------------------------------------------------------

/// Label of the input layer holding the iris/mask code.
pub const V3_SIGN_BITS_LAYER: &str = "Sign Bits";
/// Label of the shred holding the sign bits (same text as [`V3_SIGN_BITS_LAYER`]).
pub const V3_SIGN_BITS_SHRED: &str = "Sign Bits";

// ---- Auxiliary (public) input layer -------------------------------------------------------

/// Label of the input layer holding all the other public inputs (lookup table values, the
/// values to subtract from the matmult result, and the RH multiplicand of the matmult).
pub const V3_AUXILIARY_LAYER: &str = "Auxiliary Data";
/// Label of the shred holding the values subtracted from the matmult result.
pub const V3_TO_SUB_MATMULT_SHRED: &str = "Input to subtract from MatMult";
/// Label of the shred holding the right-hand multiplicand of the matmult.
pub const V3_RH_MATMULT_SHRED: &str = "RH Multiplicand of MatMult";
/// Label of the shred holding the lookup table values for the digit range check.
pub const V3_LOOKUP_SHRED: &str = "Lookup table values for digit range check";
33
34/// Build the [`Circuit<F>`] which is the circuit description of the iris code circuit.
35pub fn build_iriscode_circuit_description<
36    F: Field,
37    const IM_STRIP_ROWS: usize,
38    const IM_STRIP_COLS: usize,
39    const IM_NUM_VARS: usize,
40    const MATMULT_ROWS_NUM_VARS: usize,
41    const MATMULT_COLS_NUM_VARS: usize,
42    const MATMULT_INTERNAL_DIM_NUM_VARS: usize,
43    const BASE: u64,
44    const NUM_DIGITS: usize,
45>(
46    layer_visibility: LayerVisibility,
47    image_strip_reroutings: Vec<Vec<(u32, u32)>>,
48    lh_matrix_reroutings: Vec<(u32, u32)>,
49) -> Result<Circuit<F>> {
50    let mut builder = CircuitBuilder::<F>::new();
51
52    assert!(BASE.is_power_of_two());
53    let log_base = log2_ceil(BASE) as usize;
54    let num_strips = image_strip_reroutings.len();
55    assert!(num_strips.is_power_of_two());
56    let log_num_strips = log2_ceil(num_strips) as usize;
57
58    // Image input layer
59    let to_reroute_input_layer = builder.add_input_layer(V3_INPUT_IMAGE_LAYER, layer_visibility);
60    let to_reroute =
61        builder.add_input_shred(V3_INPUT_IMAGE_SHRED, IM_NUM_VARS, &to_reroute_input_layer);
62
63    // Digits and multiplicities input layer
64    let digits_input_layer = builder.add_input_layer(V3_DIGITS_LAYER, layer_visibility);
65    let digits_input_shreds: Vec<_> = (0..NUM_DIGITS)
66        .map(|i| {
67            builder.add_input_shred(
68                &format!("{V3_DIGITS_SHRED_TEMPLATE} {i}"),
69                log_num_strips + MATMULT_ROWS_NUM_VARS + MATMULT_COLS_NUM_VARS,
70                &digits_input_layer,
71            )
72        })
73        .collect();
74
75    let digit_multiplicities = builder.add_input_shred(
76        V3_DIGITS_MULTIPLICITIES_SHRED,
77        log_base,
78        &digits_input_layer,
79    );
80
81    // Auxiliary inputs
82    let auxiliary_input_layer =
83        builder.add_input_layer(V3_AUXILIARY_LAYER, LayerVisibility::Public);
84
85    let to_sub_from_matmult = builder.add_input_shred(
86        V3_TO_SUB_MATMULT_SHRED,
87        log_num_strips + MATMULT_ROWS_NUM_VARS + MATMULT_COLS_NUM_VARS,
88        &auxiliary_input_layer,
89    );
90
91    let rh_matmult_multiplicand = builder.add_input_shred(
92        V3_RH_MATMULT_SHRED,
93        MATMULT_INTERNAL_DIM_NUM_VARS + MATMULT_COLS_NUM_VARS,
94        &auxiliary_input_layer,
95    );
96
97    let lookup_table_values =
98        builder.add_input_shred(V3_LOOKUP_SHRED, log_base, &auxiliary_input_layer);
99
100    // Sign bits (iris/mask code)
101    let sign_bits_input_layer = builder.add_input_layer(V3_SIGN_BITS_LAYER, layer_visibility);
102    let sign_bits = builder.add_input_shred(
103        V3_SIGN_BITS_SHRED,
104        log_num_strips + MATMULT_ROWS_NUM_VARS + MATMULT_COLS_NUM_VARS,
105        &sign_bits_input_layer,
106    );
107
108    // Verifier challenges for RLC
109    let rlc_challenges = (0..num_strips)
110        .map(|_| builder.add_fiat_shamir_challenge_node(1))
111        .collect_vec();
112    let rlc_challenges_generic = rlc_challenges
113        .clone()
114        .into_iter()
115        .map(|node| node.into())
116        .collect_vec();
117
118    // Verifier challenge for lookup
119    let lookup_challenge = builder.add_fiat_shamir_challenge_node(1);
120
121    // Intermediate layers
122
123    // Image decomposition layers
124    let image_strip_nodes = image_strip_reroutings
125        .into_iter()
126        .map(|reroutings| {
127            builder.add_identity_gate_node(
128                &to_reroute,
129                reroutings,
130                log2_ceil(IM_STRIP_ROWS * IM_STRIP_COLS) as usize,
131                None,
132            )
133        })
134        .collect_vec();
135
136    // Image RLC layer
137    let image_rlc = ZkIriscodeComponent::sum_of_products(
138        &mut builder,
139        rlc_challenges_generic.iter().collect(),
140        image_strip_nodes.iter().collect(),
141    );
142
143    // Reroute the image to the LH matrix multiplicand
144    let rerouted_image = builder.add_identity_gate_node(
145        &image_rlc,
146        lh_matrix_reroutings,
147        MATMULT_ROWS_NUM_VARS + MATMULT_INTERNAL_DIM_NUM_VARS,
148        None,
149    );
150
151    // Matmult layer
152    let matmult = builder.add_matmult_node(
153        &rerouted_image,
154        (MATMULT_ROWS_NUM_VARS, MATMULT_INTERNAL_DIM_NUM_VARS),
155        &rh_matmult_multiplicand,
156        (MATMULT_INTERNAL_DIM_NUM_VARS, MATMULT_COLS_NUM_VARS),
157    );
158
159    // Thresholds RLC layer
160    let to_sub_from_matmult_splits = builder.add_split_node(&to_sub_from_matmult, log_num_strips);
161
162    let to_sub_from_matmult_rlc = ZkIriscodeComponent::sum_of_products(
163        &mut builder,
164        rlc_challenges_generic.iter().collect(),
165        to_sub_from_matmult_splits.iter().collect(),
166    );
167
168    // Subtract the thresholds from the result of matmult
169    let subtractor = builder.add_sector(matmult - to_sub_from_matmult_rlc);
170
171    // Create an RLC node for each of the NUM_DIGITS digital places
172    let digits_split_nodes = digits_input_shreds
173        .iter()
174        .map(|shred| builder.add_split_node(shred, log_num_strips))
175        .collect_vec();
176    let digits_rlc = digits_split_nodes
177        .iter()
178        .map(|splits| {
179            let digit_rlc = ZkIriscodeComponent::sum_of_products(
180                &mut builder,
181                rlc_challenges_generic.iter().collect(),
182                splits.iter().collect(),
183            );
184            digit_rlc
185        })
186        .collect_vec();
187
188    // Concatenate the digits (which are stored for each digital place separately) into a single
189    // MLE for the lookup
190    let digits_concatenator = DigitComponents::digits_concatenator(
191        &mut builder,
192        &digits_input_shreds.iter().collect_vec(),
193    );
194
195    // Lookup table and constraint
196    let lookup_table = builder.add_lookup_table(&lookup_table_values, &lookup_challenge);
197    // println!("{:?} = Lookup table", builder.get_id(&lookup_table));
198    let _lookup_constraint =
199        builder.add_lookup_constraint(&lookup_table, &digits_concatenator, &digit_multiplicities);
200    // println!("{:?} = Lookup constraint", lookup_constraint.id());
201
202    // Form the unsigned recomposition of the RLC'd digits
203    let unsigned_recomp = DigitComponents::unsigned_recomposition(
204        &mut builder,
205        &digits_rlc.iter().collect_vec(),
206        BASE,
207    );
208
209    // Iriscode RLC layer
210    let sign_bits_splits = builder.add_split_node(&sign_bits, log_num_strips);
211
212    let sign_bits_rlc = ZkIriscodeComponent::sum_of_products(
213        &mut builder,
214        rlc_challenges_generic.iter().collect(),
215        sign_bits_splits.iter().collect(),
216    );
217
218    // Complementary recomp check using the unsigned recomp of the RLC'd digits and the RLC'd sign bits
219    let complementary_checker = DigitComponents::complementary_recomp_check(
220        &mut builder,
221        &subtractor,
222        &sign_bits_rlc,
223        &unsigned_recomp,
224        BASE,
225        NUM_DIGITS,
226    );
227    builder.set_output(&complementary_checker);
228
229    let bits_are_binary = DigitComponents::bits_are_binary(&mut builder, &sign_bits);
230    builder.set_output(&bits_are_binary);
231
232    // Generate the circuit description and input builder
233    builder.build_without_layer_combination()
234}
235
236pub fn iriscode_ss_attach_aux_data<F: Field, const BASE: u64>(
237    mut circuit: Circuit<F>,
238    iriscode_aux_data: IriscodeCircuitAuxData<F>,
239) -> Result<Circuit<F>> {
240    circuit.set_input(
241        V3_RH_MATMULT_SHRED,
242        iriscode_aux_data.rh_matmult_multiplicand,
243    );
244
245    circuit.set_input(
246        V3_TO_SUB_MATMULT_SHRED,
247        iriscode_aux_data.to_sub_from_matmult,
248    );
249
250    circuit.set_input(
251        V3_LOOKUP_SHRED,
252        MultilinearExtension::new((0..BASE).map(F::from).collect()),
253    );
254
255    Ok(circuit)
256}
257
258/// Generates a mapping from Layer IDs to their respective MLEs,
259/// by attaching the `iriscode_data` onto a circuit that is
260/// described through the `input_builder_metadata`.
261pub fn iriscode_ss_attach_input_data<F: Field, const BASE: u64>(
262    mut circuit: Circuit<F>,
263    iriscode_input_data: IriscodeCircuitInputData<F>,
264    iriscode_aux_data: IriscodeCircuitAuxData<F>,
265) -> Result<Circuit<F>> {
266    circuit.set_input(V3_INPUT_IMAGE_SHRED, iriscode_input_data.to_reroute);
267    circuit.set_input(
268        V3_RH_MATMULT_SHRED,
269        iriscode_aux_data.rh_matmult_multiplicand,
270    );
271
272    iriscode_input_data
273        .digits
274        .into_iter()
275        .enumerate()
276        .for_each(|(i, mle)| {
277            circuit.set_input(&format!("{V3_DIGITS_SHRED_TEMPLATE} {i}"), mle);
278        });
279
280    circuit.set_input(V3_SIGN_BITS_SHRED, iriscode_input_data.sign_bits);
281    circuit.set_input(
282        V3_TO_SUB_MATMULT_SHRED,
283        iriscode_aux_data.to_sub_from_matmult,
284    );
285    circuit.set_input(
286        V3_DIGITS_MULTIPLICITIES_SHRED,
287        iriscode_input_data.digit_multiplicities,
288    );
289    circuit.set_input(
290        V3_LOOKUP_SHRED,
291        MultilinearExtension::new((0..BASE).map(F::from).collect()),
292    );
293
294    Ok(circuit)
295}