1use std::collections::HashMap;
2
3use ark_std::log2;
4use itertools::Itertools;
5use rand::Rng;
6use shared_types::Field;
7
8use remainder::{
9 layer::{
10 gate::{compute_gate_data_outputs, BinaryOperation},
11 matmult::product_two_matrices_from_flattened_vectors,
12 },
13 mle::evals::MultilinearExtension,
14};
15
16use crate::{
17 hyrax_worldcoin_mpc::mpc_prover::MPCCircuitConstData,
18 worldcoin_mpc::parameters::{
19 EVALUATION_POINTS_U64, GR4_MULTIPLICATION_WIRINGS, TEST_MASKED_IRIS_CODES,
20 TEST_RANDOMNESSES, TEST_SHARES,
21 },
22};
23
24use super::parameters::{ENCODING_MATRIX_U64_TRANSPOSE, GR4_MODULUS};
25
26#[derive(Debug, Clone)]
28pub struct MPCCircuitInputData<F: Field> {
29 pub iris_codes: MultilinearExtension<F>,
33
34 pub masks: MultilinearExtension<F>,
38
39 pub slopes: MultilinearExtension<F>,
42
43 pub quotients: MultilinearExtension<F>,
52
53 pub shares_reduced_modulo_gr4_modulus: MultilinearExtension<F>,
58
59 pub multiplicities_shares: MultilinearExtension<F>,
62
63 pub multiplicities_slopes: MultilinearExtension<F>,
66}
67
68pub fn gen_mpc_evaluation_points<
69 F: Field,
70 const NUM_IRIS_4_CHUNKS: usize,
71 const PARTY_IDX: usize,
72>() -> MultilinearExtension<F> {
73 MultilinearExtension::new(
74 EVALUATION_POINTS_U64[PARTY_IDX]
75 .into_iter()
76 .map(|x| F::from(x))
77 .cycle()
78 .take(NUM_IRIS_4_CHUNKS * 4)
79 .collect(),
80 )
81}
82
83pub fn gen_mpc_encoding_matrix<F: Field, const NUM_IRIS_4_CHUNKS: usize>() -> MultilinearExtension<F>
84{
85 MultilinearExtension::new(
86 ENCODING_MATRIX_U64_TRANSPOSE
87 .into_iter()
88 .map(|x| F::from(x))
89 .collect_vec(),
90 )
91}
92
93pub fn gen_mpc_common_aux_data<F: Field, const NUM_IRIS_4_CHUNKS: usize, const PARTY_IDX: usize>(
94) -> MPCCircuitConstData<F> {
95 let evaluation_points = gen_mpc_evaluation_points::<F, NUM_IRIS_4_CHUNKS, PARTY_IDX>();
96 let encoding_matrix = gen_mpc_encoding_matrix::<F, NUM_IRIS_4_CHUNKS>();
97 let lookup_table_values = MultilinearExtension::new((0..GR4_MODULUS).map(F::from).collect());
98
99 MPCCircuitConstData {
100 evaluation_points,
101 encoding_matrix,
102 lookup_table_values,
103 }
104}
105
106#[allow(clippy::type_complexity)]
111pub fn gen_mpc_input_data<F: Field, const NUM_IRIS_4_CHUNKS: usize>(
112 iris_codes: &MultilinearExtension<F>,
113 masks: &MultilinearExtension<F>,
114 slopes: &MultilinearExtension<F>,
115 encoding_matrix: &MultilinearExtension<F>,
116 evaluation_points: &MultilinearExtension<F>,
117) -> MPCCircuitInputData<F> {
118 let num_copies = NUM_IRIS_4_CHUNKS;
119
120 let mut masked_iris_codes = iris_codes
124 .iter()
125 .zip(masks.iter())
126 .map(|(iris_code, mask)| F::from(2).neg() * iris_code - mask + F::from(GR4_MODULUS))
127 .collect_vec();
128
129 let encoded_masked_iris_code = product_two_matrices_from_flattened_vectors(
131 &masked_iris_codes,
132 &encoding_matrix.to_vec(),
133 num_copies,
134 4,
135 4,
136 4,
137 );
138
139 let evaluation_points_times_slopes = compute_gate_data_outputs(
140 GR4_MULTIPLICATION_WIRINGS.to_vec(),
141 log2(num_copies.next_power_of_two()) as usize,
142 evaluation_points,
143 slopes,
144 BinaryOperation::Mul,
145 );
146
147 let mut shares_before_modulo_gr4 = encoded_masked_iris_code
148 .into_iter()
149 .zip(evaluation_points_times_slopes.iter())
150 .map(|(a, b)| a + b)
151 .collect_vec();
152
153 let (quotients, expected_shares): (Vec<F>, Vec<F>) = shares_before_modulo_gr4
156 .clone()
157 .into_iter()
158 .map(|x| {
159 let mut bytes = x.to_bytes_le();
160 let mut without_first_two_bytes = bytes.split_off(2);
161
162 without_first_two_bytes.append(&mut [0u8, 0u8].to_vec());
164
165 bytes.append(&mut [0u8; 30].to_vec());
167
168 (
169 F::from_bytes_le(&without_first_two_bytes),
170 F::from_bytes_le(&bytes),
171 )
172 })
173 .unzip();
174
175 let f_gr4_modulus = F::from(GR4_MODULUS);
176
177 let mut counts_shares: HashMap<F, u64> = HashMap::new();
179 expected_shares.iter().for_each(|x| {
180 assert!(x < &f_gr4_modulus);
182
183 *counts_shares.entry(*x).or_insert(0) += 1;
184 });
185
186 let mut multiplicities_shares = vec![F::ZERO; GR4_MODULUS as usize];
187 counts_shares.iter().for_each(|(k, v)| {
188 multiplicities_shares[k.to_u64s_le()[0] as usize] = F::from(*v);
189 });
190 let num_elements = num_copies * 4;
192 let num_zeros = num_elements.next_power_of_two() - num_elements;
193 multiplicities_shares[0] += F::from(num_zeros as u64);
194
195 let mut counts_slopes: HashMap<F, u64> = HashMap::new();
197 slopes.iter().for_each(|x| {
198 assert!(x < f_gr4_modulus);
200
201 *counts_slopes.entry(x).or_insert(0) += 1;
202 });
203
204 let mut multiplicities_slopes = vec![F::ZERO; GR4_MODULUS as usize];
205 counts_slopes.iter().for_each(|(k, v)| {
206 multiplicities_slopes[k.to_u64s_le()[0] as usize] = F::from(*v);
207 });
208 let num_elements = num_copies * 4;
210 let num_zeros = num_elements.next_power_of_two() - num_elements;
211 multiplicities_slopes[0] += F::from(num_zeros as u64);
212
213 quotients
214 .iter()
215 .zip(shares_before_modulo_gr4.iter())
216 .zip(expected_shares.iter())
217 .for_each(|((quotient, share_before_modulo_gr4), expected_share)| {
218 assert_eq!(
219 *quotient * F::from(GR4_MODULUS) + expected_share,
220 *share_before_modulo_gr4
221 );
222 });
223
224 for f in shares_before_modulo_gr4.iter_mut() {
226 f.zeroize();
227 }
228 for f in masked_iris_codes.iter_mut() {
230 f.zeroize();
231 }
232
233 let quotients = MultilinearExtension::new(quotients);
234 let shares_reduced_modulo_gr4_modulus = MultilinearExtension::new(expected_shares);
235 let multiplicities_shares = MultilinearExtension::new(multiplicities_shares);
236 let multiplicities_slopes = MultilinearExtension::new(multiplicities_slopes);
237
238 MPCCircuitInputData::<F> {
239 iris_codes: iris_codes.clone(),
240 masks: masks.clone(),
241 slopes: slopes.clone(),
242 quotients,
243 shares_reduced_modulo_gr4_modulus,
244 multiplicities_shares,
245 multiplicities_slopes,
246 }
248}
249
250pub fn generate_trivial_test_data<
252 F: Field,
253 const NUM_IRIS_4_CHUNKS: usize,
254 const PARTY_IDX: usize,
255>() -> (MPCCircuitConstData<F>, MPCCircuitInputData<F>) {
256 let num_copies = NUM_IRIS_4_CHUNKS;
257 let mut rng = rand::thread_rng();
258
259 let iris_codes = MultilinearExtension::new(
260 (0..4 * num_copies)
261 .map(|_| F::from(rng.gen_range(0..=1)))
262 .collect(),
263 );
264 let masks = MultilinearExtension::new(
265 (0..4 * num_copies)
266 .map(|_| F::from(rng.gen_range(0..=1)))
267 .collect(),
268 );
269 let slopes = MultilinearExtension::new(
270 (0..4 * num_copies)
271 .map(|_| F::from(rng.gen_range(0..=(GR4_MODULUS - 1))))
272 .collect(),
273 );
274
275 let mpc_aux_data = gen_mpc_common_aux_data::<F, NUM_IRIS_4_CHUNKS, PARTY_IDX>();
276
277 let mpc_input_data = gen_mpc_input_data::<F, NUM_IRIS_4_CHUNKS>(
278 &iris_codes,
279 &masks,
280 &slopes,
281 &mpc_aux_data.encoding_matrix,
282 &mpc_aux_data.evaluation_points,
283 );
284
285 assert_eq!(mpc_input_data.quotients.len(), num_copies * 4);
286 assert_eq!(
287 mpc_input_data.shares_reduced_modulo_gr4_modulus.len(),
288 num_copies * 4
289 );
290 assert_eq!(slopes.len(), num_copies * 4);
291 assert_eq!(mpc_aux_data.evaluation_points.len(), num_copies * 4);
292 assert_eq!(
293 mpc_input_data.multiplicities_shares.len(),
294 GR4_MODULUS as usize
295 );
296 assert_eq!(
297 mpc_input_data.multiplicities_slopes.len(),
298 GR4_MODULUS as usize
299 );
300 assert_eq!(mpc_aux_data.lookup_table_values.len(), GR4_MODULUS as usize);
301
302 (mpc_aux_data, mpc_input_data)
303}
304
305pub fn fetch_inversed_test_data<
308 F: Field,
309 const NUM_IRIS_4_CHUNKS: usize,
310 const PARTY_IDX: usize,
311>(
312 test_idx: usize,
313) -> (MPCCircuitConstData<F>, MPCCircuitInputData<F>) {
314 let num_copies = NUM_IRIS_4_CHUNKS;
315 if test_idx + NUM_IRIS_4_CHUNKS >= TEST_MASKED_IRIS_CODES.len() {
316 panic!("test_idx out of range");
317 }
318 let mut rng = rand::thread_rng();
319
320 let masked_iris_codes = MultilinearExtension::new(
321 (0..num_copies)
322 .flat_map(|batch_idx| {
323 TEST_MASKED_IRIS_CODES[batch_idx + test_idx]
324 .into_iter()
325 .map(F::from)
326 .collect::<Vec<F>>()
327 })
328 .collect_vec(),
329 );
330 let iris_codes = MultilinearExtension::new(
331 (0..num_copies * 4)
332 .map(|_| F::from(rng.gen_range(0..=1)))
333 .collect(),
334 );
335 assert_eq!(masked_iris_codes.len(), iris_codes.len());
336 let masks = MultilinearExtension::new(
337 masked_iris_codes
338 .iter()
339 .zip(iris_codes.iter())
340 .map(|(masked_iris_code, iris_code)| F::from(2).neg() * iris_code - masked_iris_code)
341 .collect(),
342 );
343 let slopes = MultilinearExtension::new(
344 (0..num_copies)
345 .flat_map(|batch_idx| {
346 TEST_RANDOMNESSES[batch_idx + test_idx]
347 .into_iter()
348 .map(F::from)
349 .collect::<Vec<F>>()
350 })
351 .collect_vec(),
352 );
353
354 let mpc_aux_data = gen_mpc_common_aux_data::<F, NUM_IRIS_4_CHUNKS, PARTY_IDX>();
355 let mpc_input_data = gen_mpc_input_data::<F, NUM_IRIS_4_CHUNKS>(
356 &iris_codes,
357 &masks,
358 &slopes,
359 &mpc_aux_data.encoding_matrix,
360 &mpc_aux_data.evaluation_points,
361 );
362
363 assert_eq!(mpc_input_data.quotients.len(), num_copies * 4);
364 assert_eq!(
365 mpc_input_data.shares_reduced_modulo_gr4_modulus.len(),
366 num_copies * 4
367 );
368 assert_eq!(slopes.len(), num_copies * 4);
369 assert_eq!(mpc_aux_data.evaluation_points.len(), num_copies * 4);
370 assert_eq!(
371 mpc_input_data.multiplicities_shares.len(),
372 GR4_MODULUS as usize
373 );
374 assert_eq!(
375 mpc_input_data.multiplicities_slopes.len(),
376 GR4_MODULUS as usize
377 );
378 assert_eq!(mpc_aux_data.lookup_table_values.len(), GR4_MODULUS as usize);
379
380 mpc_input_data
381 .shares_reduced_modulo_gr4_modulus
382 .iter()
383 .zip(
384 (0..num_copies)
385 .flat_map(|batch_idx| TEST_SHARES[PARTY_IDX][batch_idx + test_idx].into_iter())
386 .collect_vec()
387 .iter(),
388 )
389 .for_each(|(a, b)| {
390 assert_eq!(a, F::from(*b));
391 });
392
393 (mpc_aux_data, mpc_input_data)
394}