frontend/worldcoin_mpc/
data.rs

1use std::collections::HashMap;
2
3use ark_std::log2;
4use itertools::Itertools;
5use rand::Rng;
6use shared_types::Field;
7
8use remainder::{
9    layer::{
10        gate::{compute_gate_data_outputs, BinaryOperation},
11        matmult::product_two_matrices_from_flattened_vectors,
12    },
13    mle::evals::MultilinearExtension,
14};
15
16use crate::{
17    hyrax_worldcoin_mpc::mpc_prover::MPCCircuitConstData,
18    worldcoin_mpc::parameters::{
19        EVALUATION_POINTS_U64, GR4_MULTIPLICATION_WIRINGS, TEST_MASKED_IRIS_CODES,
20        TEST_RANDOMNESSES, TEST_SHARES,
21    },
22};
23
24use super::parameters::{ENCODING_MATRIX_U64_TRANSPOSE, GR4_MODULUS};
25
26/// Used for instantiating the mpc circuit.
27#[derive(Debug, Clone)]
28pub struct MPCCircuitInputData<F: Field> {
29    /// The iris codes, they are {0, 1} valued.
30    /// Needed to calculate the masked code which we secret share
31    /// between the three parties.
32    pub iris_codes: MultilinearExtension<F>,
33
34    /// The masks, they are {0, 1} valued.
35    /// Needed to calculate the masked code which we secret share
36    /// between the three parties.
37    pub masks: MultilinearExtension<F>,
38
39    /// The slopes, they are random elements in GR4
40    /// These are generated randomly beforehand and supplied into the circuit
41    pub slopes: MultilinearExtension<F>,
42
43    /// The quotients, they are elements in GR4: GR(size_of(F), 4) is a Galois
44    /// extension of Z/size_of(F)Z over the monic polynomial x^4 - x - 1
45    /// The naively calculated secret shares might be outside of the range of
46    /// [0..2^16], which are the range of the coefficients of the particular GR4
47    /// we choose GR(2^16, 4)
48    /// Therefore, we supply `quotients` and `shares_reduced_modulo_gr4_modulus`,
49    /// so that `shares_reduced_modulo_gr4_modulus` + `quotients` * `GR4_MODULUS`
50    /// equals to the naively calculated secret shares
51    pub quotients: MultilinearExtension<F>,
52
53    /// The shares_reduced_modulo_gr4_modulus, they are elemnts in GR4: GR(2^16, 4)
54    /// We supply `quotients` and `shares_reduced_modulo_gr4_modulus`,
55    /// so that `shares_reduced_modulo_gr4_modulus` + `quotients` * `GR4_MODULUS`
56    /// equals to the naively calculated secret shares
57    pub shares_reduced_modulo_gr4_modulus: MultilinearExtension<F>,
58
59    /// The multiplicies is used for lookup (expected_shares), we calculate the occurances
60    /// of different numbers between [0..2^16]
61    pub multiplicities_shares: MultilinearExtension<F>,
62
63    /// The multiplicies is used for lookup (slope), we calculate the occurances of different
64    /// numbers between [0..2^16]
65    pub multiplicities_slopes: MultilinearExtension<F>,
66}
67
68pub fn gen_mpc_evaluation_points<
69    F: Field,
70    const NUM_IRIS_4_CHUNKS: usize,
71    const PARTY_IDX: usize,
72>() -> MultilinearExtension<F> {
73    MultilinearExtension::new(
74        EVALUATION_POINTS_U64[PARTY_IDX]
75            .into_iter()
76            .map(|x| F::from(x))
77            .cycle()
78            .take(NUM_IRIS_4_CHUNKS * 4)
79            .collect(),
80    )
81}
82
83pub fn gen_mpc_encoding_matrix<F: Field, const NUM_IRIS_4_CHUNKS: usize>() -> MultilinearExtension<F>
84{
85    MultilinearExtension::new(
86        ENCODING_MATRIX_U64_TRANSPOSE
87            .into_iter()
88            .map(|x| F::from(x))
89            .collect_vec(),
90    )
91}
92
93pub fn gen_mpc_common_aux_data<F: Field, const NUM_IRIS_4_CHUNKS: usize, const PARTY_IDX: usize>(
94) -> MPCCircuitConstData<F> {
95    let evaluation_points = gen_mpc_evaluation_points::<F, NUM_IRIS_4_CHUNKS, PARTY_IDX>();
96    let encoding_matrix = gen_mpc_encoding_matrix::<F, NUM_IRIS_4_CHUNKS>();
97    let lookup_table_values = MultilinearExtension::new((0..GR4_MODULUS).map(F::from).collect());
98
99    MPCCircuitConstData {
100        evaluation_points,
101        encoding_matrix,
102        lookup_table_values,
103    }
104}
105
106/// Selects the encoding matrix, and
107/// Calculates the quotients, expected_shares (modulo gr4), and the multiplicities
108/// returns as a tuple, in the folllowing order:
109/// (encoding_matrix, quotients, expected_shares, multiplicities)
110#[allow(clippy::type_complexity)]
111pub fn gen_mpc_input_data<F: Field, const NUM_IRIS_4_CHUNKS: usize>(
112    iris_codes: &MultilinearExtension<F>,
113    masks: &MultilinearExtension<F>,
114    slopes: &MultilinearExtension<F>,
115    encoding_matrix: &MultilinearExtension<F>,
116    evaluation_points: &MultilinearExtension<F>,
117) -> MPCCircuitInputData<F> {
118    let num_copies = NUM_IRIS_4_CHUNKS;
119
120    // let lookup_table_values = MultilinearExtension::new((0..GR4_MODULUS).map(F::from).collect());
121
122    // masked_iris_codes dimension is: (NUM_IRIS_4_CHUNKS, 4)
123    let mut masked_iris_codes = iris_codes
124        .iter()
125        .zip(masks.iter())
126        .map(|(iris_code, mask)| F::from(2).neg() * iris_code - mask + F::from(GR4_MODULUS))
127        .collect_vec();
128
129    // encoding_matrix dimension is: (4, 4)
130    let encoded_masked_iris_code = product_two_matrices_from_flattened_vectors(
131        &masked_iris_codes,
132        &encoding_matrix.to_vec(),
133        num_copies,
134        4,
135        4,
136        4,
137    );
138
139    let evaluation_points_times_slopes = compute_gate_data_outputs(
140        GR4_MULTIPLICATION_WIRINGS.to_vec(),
141        log2(num_copies.next_power_of_two()) as usize,
142        evaluation_points,
143        slopes,
144        BinaryOperation::Mul,
145    );
146
147    let mut shares_before_modulo_gr4 = encoded_masked_iris_code
148        .into_iter()
149        .zip(evaluation_points_times_slopes.iter())
150        .map(|(a, b)| a + b)
151        .collect_vec();
152
153    // because the modulo is 2^16, we can just take the smallest 16 bits as the
154    // reduced modulo gr4 shares, and the rest bits as the quotients
155    let (quotients, expected_shares): (Vec<F>, Vec<F>) = shares_before_modulo_gr4
156        .clone()
157        .into_iter()
158        .map(|x| {
159            let mut bytes = x.to_bytes_le();
160            let mut without_first_two_bytes = bytes.split_off(2);
161
162            // for quotient: pads the rest two zero bytes at the end
163            without_first_two_bytes.append(&mut [0u8, 0u8].to_vec());
164
165            // for remainder(modulus): pads the rest 30 zero bytes
166            bytes.append(&mut [0u8; 30].to_vec());
167
168            (
169                F::from_bytes_le(&without_first_two_bytes),
170                F::from_bytes_le(&bytes),
171            )
172        })
173        .unzip();
174
175    let f_gr4_modulus = F::from(GR4_MODULUS);
176
177    // calculates the multiplicities of shares
178    let mut counts_shares: HashMap<F, u64> = HashMap::new();
179    expected_shares.iter().for_each(|x| {
180        // check that indeed the shares are less than the modulus
181        assert!(x < &f_gr4_modulus);
182
183        *counts_shares.entry(*x).or_insert(0) += 1;
184    });
185
186    let mut multiplicities_shares = vec![F::ZERO; GR4_MODULUS as usize];
187    counts_shares.iter().for_each(|(k, v)| {
188        multiplicities_shares[k.to_u64s_le()[0] as usize] = F::from(*v);
189    });
190    // number of 0s as implicit paddings
191    let num_elements = num_copies * 4;
192    let num_zeros = num_elements.next_power_of_two() - num_elements;
193    multiplicities_shares[0] += F::from(num_zeros as u64);
194
195    // the same process for slopes
196    let mut counts_slopes: HashMap<F, u64> = HashMap::new();
197    slopes.iter().for_each(|x| {
198        // check that indeed the shares are less than the modulus
199        assert!(x < f_gr4_modulus);
200
201        *counts_slopes.entry(x).or_insert(0) += 1;
202    });
203
204    let mut multiplicities_slopes = vec![F::ZERO; GR4_MODULUS as usize];
205    counts_slopes.iter().for_each(|(k, v)| {
206        multiplicities_slopes[k.to_u64s_le()[0] as usize] = F::from(*v);
207    });
208    // number of 0s as implicit paddings
209    let num_elements = num_copies * 4;
210    let num_zeros = num_elements.next_power_of_two() - num_elements;
211    multiplicities_slopes[0] += F::from(num_zeros as u64);
212
213    quotients
214        .iter()
215        .zip(shares_before_modulo_gr4.iter())
216        .zip(expected_shares.iter())
217        .for_each(|((quotient, share_before_modulo_gr4), expected_share)| {
218            assert_eq!(
219                *quotient * F::from(GR4_MODULUS) + expected_share,
220                *share_before_modulo_gr4
221            );
222        });
223
224    // zeroize shares_before_modulo_gr4
225    for f in shares_before_modulo_gr4.iter_mut() {
226        f.zeroize();
227    }
228    // zeroize masked_iris_codes
229    for f in masked_iris_codes.iter_mut() {
230        f.zeroize();
231    }
232
233    let quotients = MultilinearExtension::new(quotients);
234    let shares_reduced_modulo_gr4_modulus = MultilinearExtension::new(expected_shares);
235    let multiplicities_shares = MultilinearExtension::new(multiplicities_shares);
236    let multiplicities_slopes = MultilinearExtension::new(multiplicities_slopes);
237
238    MPCCircuitInputData::<F> {
239        iris_codes: iris_codes.clone(),
240        masks: masks.clone(),
241        slopes: slopes.clone(),
242        quotients,
243        shares_reduced_modulo_gr4_modulus,
244        multiplicities_shares,
245        multiplicities_slopes,
246        // lookup_table_values,
247    }
248}
249
250/// create test data for mpc circuits, control the size of such
251pub fn generate_trivial_test_data<
252    F: Field,
253    const NUM_IRIS_4_CHUNKS: usize,
254    const PARTY_IDX: usize,
255>() -> (MPCCircuitConstData<F>, MPCCircuitInputData<F>) {
256    let num_copies = NUM_IRIS_4_CHUNKS;
257    let mut rng = rand::thread_rng();
258
259    let iris_codes = MultilinearExtension::new(
260        (0..4 * num_copies)
261            .map(|_| F::from(rng.gen_range(0..=1)))
262            .collect(),
263    );
264    let masks = MultilinearExtension::new(
265        (0..4 * num_copies)
266            .map(|_| F::from(rng.gen_range(0..=1)))
267            .collect(),
268    );
269    let slopes = MultilinearExtension::new(
270        (0..4 * num_copies)
271            .map(|_| F::from(rng.gen_range(0..=(GR4_MODULUS - 1))))
272            .collect(),
273    );
274
275    let mpc_aux_data = gen_mpc_common_aux_data::<F, NUM_IRIS_4_CHUNKS, PARTY_IDX>();
276
277    let mpc_input_data = gen_mpc_input_data::<F, NUM_IRIS_4_CHUNKS>(
278        &iris_codes,
279        &masks,
280        &slopes,
281        &mpc_aux_data.encoding_matrix,
282        &mpc_aux_data.evaluation_points,
283    );
284
285    assert_eq!(mpc_input_data.quotients.len(), num_copies * 4);
286    assert_eq!(
287        mpc_input_data.shares_reduced_modulo_gr4_modulus.len(),
288        num_copies * 4
289    );
290    assert_eq!(slopes.len(), num_copies * 4);
291    assert_eq!(mpc_aux_data.evaluation_points.len(), num_copies * 4);
292    assert_eq!(
293        mpc_input_data.multiplicities_shares.len(),
294        GR4_MODULUS as usize
295    );
296    assert_eq!(
297        mpc_input_data.multiplicities_slopes.len(),
298        GR4_MODULUS as usize
299    );
300    assert_eq!(mpc_aux_data.lookup_table_values.len(), GR4_MODULUS as usize);
301
302    (mpc_aux_data, mpc_input_data)
303}
304
305/// Fetch one quadruplets from the test data given by Inversed,
306/// `test_idx` specifies which copy
307pub fn fetch_inversed_test_data<
308    F: Field,
309    const NUM_IRIS_4_CHUNKS: usize,
310    const PARTY_IDX: usize,
311>(
312    test_idx: usize,
313) -> (MPCCircuitConstData<F>, MPCCircuitInputData<F>) {
314    let num_copies = NUM_IRIS_4_CHUNKS;
315    if test_idx + NUM_IRIS_4_CHUNKS >= TEST_MASKED_IRIS_CODES.len() {
316        panic!("test_idx out of range");
317    }
318    let mut rng = rand::thread_rng();
319
320    let masked_iris_codes = MultilinearExtension::new(
321        (0..num_copies)
322            .flat_map(|batch_idx| {
323                TEST_MASKED_IRIS_CODES[batch_idx + test_idx]
324                    .into_iter()
325                    .map(F::from)
326                    .collect::<Vec<F>>()
327            })
328            .collect_vec(),
329    );
330    let iris_codes = MultilinearExtension::new(
331        (0..num_copies * 4)
332            .map(|_| F::from(rng.gen_range(0..=1)))
333            .collect(),
334    );
335    assert_eq!(masked_iris_codes.len(), iris_codes.len());
336    let masks = MultilinearExtension::new(
337        masked_iris_codes
338            .iter()
339            .zip(iris_codes.iter())
340            .map(|(masked_iris_code, iris_code)| F::from(2).neg() * iris_code - masked_iris_code)
341            .collect(),
342    );
343    let slopes = MultilinearExtension::new(
344        (0..num_copies)
345            .flat_map(|batch_idx| {
346                TEST_RANDOMNESSES[batch_idx + test_idx]
347                    .into_iter()
348                    .map(F::from)
349                    .collect::<Vec<F>>()
350            })
351            .collect_vec(),
352    );
353
354    let mpc_aux_data = gen_mpc_common_aux_data::<F, NUM_IRIS_4_CHUNKS, PARTY_IDX>();
355    let mpc_input_data = gen_mpc_input_data::<F, NUM_IRIS_4_CHUNKS>(
356        &iris_codes,
357        &masks,
358        &slopes,
359        &mpc_aux_data.encoding_matrix,
360        &mpc_aux_data.evaluation_points,
361    );
362
363    assert_eq!(mpc_input_data.quotients.len(), num_copies * 4);
364    assert_eq!(
365        mpc_input_data.shares_reduced_modulo_gr4_modulus.len(),
366        num_copies * 4
367    );
368    assert_eq!(slopes.len(), num_copies * 4);
369    assert_eq!(mpc_aux_data.evaluation_points.len(), num_copies * 4);
370    assert_eq!(
371        mpc_input_data.multiplicities_shares.len(),
372        GR4_MODULUS as usize
373    );
374    assert_eq!(
375        mpc_input_data.multiplicities_slopes.len(),
376        GR4_MODULUS as usize
377    );
378    assert_eq!(mpc_aux_data.lookup_table_values.len(), GR4_MODULUS as usize);
379
380    mpc_input_data
381        .shares_reduced_modulo_gr4_modulus
382        .iter()
383        .zip(
384            (0..num_copies)
385                .flat_map(|batch_idx| TEST_SHARES[PARTY_IDX][batch_idx + test_idx].into_iter())
386                .collect_vec()
387                .iter(),
388        )
389        .for_each(|(a, b)| {
390            assert_eq!(a, F::from(*b));
391        });
392
393    (mpc_aux_data, mpc_input_data)
394}