frontend/zk_iriscode_ss/
data.rs

1use crate::digits::{complementary_decomposition, digits_to_field, to_slice_of_mles};
2use itertools::Itertools;
3use ndarray::{Array, Array2};
4use remainder::mle::evals::MultilinearExtension;
5use remainder::utils::arithmetic::i64_to_field;
6use remainder::utils::mle::pad_with;
7use serde::{Deserialize, Serialize};
8use shared_types::Field;
9
10/// Input data for the Worldcoin iriscode circuit.
11#[derive(Debug, Clone)]
12pub struct IriscodeCircuitInputData<F: Field> {
13    /// The values to be re-routed to form the LH multiplicand of the matrix multiplication.
14    /// Length is a power of two.
15    pub to_reroute: MultilinearExtension<F>,
16    /// The digits of the complementary digital decompositions (base BASE) of matmult minus `to_sub_from_matmult`.
17    /// Length of each MLE is `1 << (MATMULT_ROWS_NUM_VARS + MATMULT_COLS_NUM_VARS)`.
18    pub digits: Vec<MultilinearExtension<F>>,
19    /// The bits of the complementary digital decompositions of the values
20    ///     matmult - to_sub_from_matmult.
21    /// (This is the iris code (if processing the iris image) or the mask code (if processing the mask).)
22    /// Length is `1 << (MATMULT_ROWS_NUM_VARS + MATMULT_COLS_NUM_VARS)`.
23    pub sign_bits: MultilinearExtension<F>,
24    /// The number of times each digit 0 .. BASE - 1 occurs in the complementary digital decompositions of
25    /// response - threshold.
26    /// Length is `BASE`.
27    pub digit_multiplicities: MultilinearExtension<F>,
28}
29
30/// Auxiliary input data for the Worldcoin iriscode circuit.
31#[derive(Debug, Clone, Serialize, Deserialize)]
32#[serde(bound = "F: Field")]
33pub struct IriscodeCircuitAuxData<F: Field> {
34    /// The MLE of the RH multiplicand of the matrix multiplication.
35    /// Length is `1 << (MATMULT_INTERNAL_DIM_NUM_VARS + MATMULT_COLS_NUM_VARS)`.
36    pub rh_matmult_multiplicand: MultilinearExtension<F>,
37
38    /// Values to be subtracted from the result of the matrix multiplication.
39    /// Length is `1 << (MATMULT_ROWS_NUM_VARS + MATMULT_COLS_NUM_VARS)`.
40    pub to_sub_from_matmult: MultilinearExtension<F>,
41}
42
/// Converts 2-D wirings into re-routings between the corresponding flattened (row-major) 1-D MLEs.
///
/// Each wiring is a 4-tuple of `u16`s mapping a coordinate of the source matrix to a coordinate of
/// the destination matrix; each is turned into a pair of flat indices, e.g.
/// `src_idx = src_row_idx * src_arr_num_cols + src_col_idx`.
///
/// Input order is `(src_row_idx, src_col_idx, dest_row_idx, dest_col_idx)`.
/// Output order is `(dest_idx, src_idx)` (to match [remainder::layer::identity_gate::IdentityGate]).
///
/// # Panics
///
/// Panics if a flattened index does not fit in a `u32`, rather than silently truncating it into
/// an unrelated (but valid-looking) wire.
pub fn wirings_to_reroutings(
    wirings: &[(u16, u16, u16, u16)],
    src_arr_num_cols: usize,
    dest_arr_num_cols: usize,
) -> Vec<(u32, u32)> {
    // Row-major flattening of `(row, col)` in a matrix with `num_cols` columns, checked end to end.
    let flatten = |row: u16, col: u16, num_cols: usize| -> u32 {
        usize::from(row)
            .checked_mul(num_cols)
            .and_then(|offset| offset.checked_add(usize::from(col)))
            .and_then(|idx| u32::try_from(idx).ok())
            .unwrap_or_else(|| {
                panic!(
                    "flat index of (row {row}, col {col}) with {num_cols} columns \
                     does not fit in a u32"
                )
            })
    };

    wirings
        .iter()
        .map(|&(src_row_idx, src_col_idx, dest_row_idx, dest_col_idx)| {
            (
                flatten(dest_row_idx, dest_col_idx, dest_arr_num_cols),
                flatten(src_row_idx, src_col_idx, src_arr_num_cols),
            )
        })
        .collect()
}
68
69pub fn build_iriscode_circuit_auxiliary_data<
70    F: Field,
71    const MATMULT_COLS_NUM_VARS: usize,
72    const MATMULT_INTERNAL_DIM_NUM_VARS: usize,
73    const NUM_STRIPS: usize,
74    const MAT_CHUNK_SIZE: usize,
75>(
76    rh_multiplicand: &[i32],
77    thresholds_matrix: &[i64],
78) -> IriscodeCircuitAuxData<F> {
79    // Build the RH multiplicand for the matmult.
80    let rh_multiplicand = Array2::from_shape_vec(
81        (
82            1 << MATMULT_INTERNAL_DIM_NUM_VARS,
83            1 << MATMULT_COLS_NUM_VARS,
84        ),
85        rh_multiplicand.iter().map(|&x| x as i64).collect_vec(),
86    )
87    .unwrap();
88
89    // Flatten the kernel values, convert to field.  (Already padded)
90    let rh_matmult_multiplicand: Vec<F> =
91        rh_multiplicand.into_iter().map(i64_to_field).collect_vec();
92
93    // Build the thresholds matrix from the 1d serialization.
94    let thresholds_matrix = Array2::from_shape_vec(
95        (NUM_STRIPS * MAT_CHUNK_SIZE, 1 << MATMULT_COLS_NUM_VARS),
96        thresholds_matrix.to_vec(),
97    )
98    .unwrap();
99
100    // Flatten the thresholds matrix, convert to field and pad.
101    let thresholds_matrix: Vec<F> = pad_with(
102        F::ZERO,
103        &thresholds_matrix
104            .into_iter()
105            .map(i64_to_field)
106            .collect_vec(),
107    );
108
109    IriscodeCircuitAuxData {
110        rh_matmult_multiplicand: MultilinearExtension::new(rh_matmult_multiplicand),
111        to_sub_from_matmult: MultilinearExtension::new(thresholds_matrix),
112    }
113}
114
115/// Build an instance of [IriscodeCircuitInputData] from the given image, RH multiplicand, thresholds and
116/// wiring data, by deriving the iris code.
117pub fn build_iriscode_circuit_data<
118    F: Field,
119    const IM_STRIP_ROWS: usize,
120    const IM_STRIP_COLS: usize,
121    const MATMULT_ROWS_NUM_VARS: usize,
122    const MATMULT_COLS_NUM_VARS: usize,
123    const MATMULT_INTERNAL_DIM_NUM_VARS: usize,
124    const BASE: u64,
125    const NUM_DIGITS: usize,
126>(
127    image: Array2<u8>,
128    rh_multiplicand: &[i32],
129    thresholds_matrix: &[i64],
130    image_strip_wirings: Vec<Vec<(u16, u16, u16, u16)>>,
131    lh_matrix_wirings: &[(u16, u16, u16, u16)],
132) -> IriscodeCircuitInputData<F> {
133    assert!(BASE.is_power_of_two());
134    assert!(NUM_DIGITS.is_power_of_two());
135    let num_strips = image_strip_wirings.len();
136
137    // Calculate the left-hand side of the matrix multiplication
138    let mat_chunk_size = 1 << MATMULT_ROWS_NUM_VARS;
139    let mut rerouted_matrix: Array2<i64> = Array::zeros((
140        num_strips * mat_chunk_size,
141        (1 << MATMULT_INTERNAL_DIM_NUM_VARS),
142    ));
143    image_strip_wirings
144        .iter()
145        .enumerate()
146        .for_each(|(strip_idx, wirings)| {
147            // Build the image strip
148            let mut image_strip: Array2<i64> = Array::zeros((IM_STRIP_ROWS, IM_STRIP_COLS));
149            wirings.iter().for_each(|row| {
150                let (im_row, im_col, im_strip_row, im_strip_col) = (
151                    row.0 as usize,
152                    row.1 as usize,
153                    row.2 as usize,
154                    row.3 as usize,
155                );
156                image_strip[[im_strip_row, im_strip_col]] = image[[im_row, im_col]] as i64;
157            });
158            // Route the image strip into the (un RLC'd) LH matrix
159            lh_matrix_wirings.iter().for_each(|row| {
160                let (im_strip_row, im_strip_col, mat_row, mat_col) = (
161                    row.0 as usize,
162                    row.1 as usize,
163                    row.2 as usize,
164                    row.3 as usize,
165                );
166                rerouted_matrix[[strip_idx * mat_chunk_size + mat_row, mat_col]] =
167                    image_strip[[im_strip_row, im_strip_col]];
168            });
169        });
170
171    // Build the RH multiplicand for the matmult.
172    let rh_multiplicand = Array2::from_shape_vec(
173        (
174            1 << MATMULT_INTERNAL_DIM_NUM_VARS,
175            1 << MATMULT_COLS_NUM_VARS,
176        ),
177        rh_multiplicand.iter().map(|&x| x as i64).collect_vec(),
178    )
179    .unwrap();
180
181    // Build the thresholds matrix from the 1d serialization.
182    let thresholds_matrix = Array2::from_shape_vec(
183        (num_strips * mat_chunk_size, 1 << MATMULT_COLS_NUM_VARS),
184        thresholds_matrix.to_vec(),
185    )
186    .unwrap();
187
188    // Calculate the matrix product. Has dimensions (1 << MATMULT_ROWS_NUM_VARS, 1 << MATMULT_COLS_NUM_VARS).
189    let responses = rerouted_matrix.dot(&rh_multiplicand);
190
191    // Calculate the thresholded responses, which are the responses minus the thresholds. We pad
192    // the thresholded responses to the nearest power of two, since logup expects the number of
193    // constrained values (which will be the digits of the decomps of the threshold responses)
194    // to be a power of two.
195    let thres_resp = pad_with(
196        0,
197        &(responses - &thresholds_matrix).into_iter().collect_vec(),
198    );
199
200    // Calculate the complementary digital decompositions of the thresholded responses.
201    // Both vectors have the same length as thres_resp.
202    let (digits, code): (Vec<_>, Vec<_>) = thres_resp
203        .into_iter()
204        .map(|value| complementary_decomposition::<BASE, NUM_DIGITS>(value).unwrap())
205        .unzip();
206
207    // Count the number of times each digit occurs.
208    let mut digit_multiplicities: Vec<usize> = vec![0; BASE as usize];
209    digits.iter().for_each(|decomp| {
210        decomp.iter().for_each(|&digit| {
211            digit_multiplicities[digit as usize] += 1;
212        })
213    });
214
215    // Derive the padded image MLE.
216    // Note that this padding has nothing to do with the padding of the thresholded responses.
217    let image_matrix_mle: Vec<F> = pad_with(0, &image.into_iter().collect_vec())
218        .into_iter()
219        .map(|v| F::from(v as u64))
220        .collect_vec();
221
222    // Convert the iris code to field elements (this is already padded by construction).
223    let code: Vec<F> = code
224        .into_iter()
225        .map(|elem| F::from(elem as u64))
226        .collect_vec();
227
228    // Convert the digit multiplicities to field elements.
229    let digit_multiplicities = digit_multiplicities
230        .into_iter()
231        .map(|x| F::from(x as u64))
232        .collect_vec();
233    let digits = to_slice_of_mles(digits.iter().map(digits_to_field).collect_vec()).to_vec();
234
235    IriscodeCircuitInputData {
236        to_reroute: MultilinearExtension::new(image_matrix_mle),
237        digits,
238        sign_bits: MultilinearExtension::new(code),
239        digit_multiplicities: MultilinearExtension::new(digit_multiplicities),
240    }
241}