1use std::collections::HashMap;
2
3use ark_std::log2;
4use itertools::Itertools;
5use rand::Rng;
6use shared_types::Field;
7
8use remainder::{
9 layer::{
10 gate::{compute_gate_data_outputs, BinaryOperation},
11 matmult::product_two_matrices_from_flattened_vectors,
12 },
13 mle::evals::MultilinearExtension,
14};
15
16use crate::{
17 hyrax_worldcoin_mpc::mpc_prover::MPCCircuitConstData,
18 worldcoin_mpc::parameters::{
19 EVALUATION_POINTS_U64, GR4_MULTIPLICATION_WIRINGS, TEST_MASKED_IRIS_CODES,
20 TEST_RANDOMNESSES, TEST_SHARES,
21 },
22};
23
24use super::parameters::{ENCODING_MATRIX_U64_TRANSPOSE, GR4_MODULUS};
25
26#[derive(Debug, Clone)]
28pub struct MPCCircuitInputData<F: Field> {
29 pub iris_codes: MultilinearExtension<F>,
33
34 pub masks: MultilinearExtension<F>,
38
39 pub slopes: MultilinearExtension<F>,
42
43 pub quotients: MultilinearExtension<F>,
52
53 pub shares_reduced_modulo_gr4_modulus: MultilinearExtension<F>,
58
59 pub multiplicities_shares: MultilinearExtension<F>,
62
63 pub multiplicities_slopes: MultilinearExtension<F>,
66
67 pub multiplicities_quotients: MultilinearExtension<F>,
70}
71
72pub fn gen_mpc_evaluation_points<
73 F: Field,
74 const NUM_IRIS_4_CHUNKS: usize,
75 const PARTY_IDX: usize,
76>() -> MultilinearExtension<F> {
77 MultilinearExtension::new(
78 EVALUATION_POINTS_U64[PARTY_IDX]
79 .into_iter()
80 .map(|x| F::from(x))
81 .cycle()
82 .take(NUM_IRIS_4_CHUNKS * 4)
83 .collect(),
84 )
85}
86
87pub fn gen_mpc_encoding_matrix<F: Field, const NUM_IRIS_4_CHUNKS: usize>() -> MultilinearExtension<F>
88{
89 MultilinearExtension::new(
90 ENCODING_MATRIX_U64_TRANSPOSE
91 .into_iter()
92 .map(|x| F::from(x))
93 .collect_vec(),
94 )
95}
96
97pub fn gen_mpc_common_aux_data<F: Field, const NUM_IRIS_4_CHUNKS: usize, const PARTY_IDX: usize>(
98) -> MPCCircuitConstData<F> {
99 let evaluation_points = gen_mpc_evaluation_points::<F, NUM_IRIS_4_CHUNKS, PARTY_IDX>();
100 let encoding_matrix = gen_mpc_encoding_matrix::<F, NUM_IRIS_4_CHUNKS>();
101 let lookup_table_values = MultilinearExtension::new((0..GR4_MODULUS).map(F::from).collect());
102 let lookup_table_values_2_19 = MultilinearExtension::new((0..(1 << 19)).map(F::from).collect());
103
104 MPCCircuitConstData {
105 evaluation_points,
106 encoding_matrix,
107 lookup_table_values,
108 lookup_table_values_2_19,
109 }
110}
111
112#[allow(clippy::type_complexity)]
117pub fn gen_mpc_input_data<F: Field, const NUM_IRIS_4_CHUNKS: usize>(
118 iris_codes: &MultilinearExtension<F>,
119 masks: &MultilinearExtension<F>,
120 slopes: &MultilinearExtension<F>,
121 encoding_matrix: &MultilinearExtension<F>,
122 evaluation_points: &MultilinearExtension<F>,
123) -> MPCCircuitInputData<F> {
124 let num_copies = NUM_IRIS_4_CHUNKS;
125
126 let mut masked_iris_codes = iris_codes
130 .iter()
131 .zip(masks.iter())
132 .map(|(iris_code, mask)| F::from(2).neg() * iris_code - mask + F::from(GR4_MODULUS))
133 .collect_vec();
134
135 let encoded_masked_iris_code = product_two_matrices_from_flattened_vectors(
137 &masked_iris_codes,
138 &encoding_matrix.to_vec(),
139 num_copies,
140 4,
141 4,
142 4,
143 );
144
145 let evaluation_points_times_slopes = compute_gate_data_outputs(
146 GR4_MULTIPLICATION_WIRINGS.to_vec(),
147 log2(num_copies.next_power_of_two()) as usize,
148 evaluation_points,
149 slopes,
150 BinaryOperation::Mul,
151 );
152
153 let mut shares_before_modulo_gr4 = encoded_masked_iris_code
154 .into_iter()
155 .zip(evaluation_points_times_slopes.iter())
156 .map(|(a, b)| a + b)
157 .collect_vec();
158
159 let (quotients, expected_shares): (Vec<F>, Vec<F>) = shares_before_modulo_gr4
162 .clone()
163 .into_iter()
164 .map(|x| {
165 let mut bytes = x.to_bytes_le();
166 let mut without_first_two_bytes = bytes.split_off(2);
167
168 without_first_two_bytes.append(&mut [0u8, 0u8].to_vec());
170
171 bytes.append(&mut [0u8; 30].to_vec());
173
174 (
175 F::from_bytes_le(&without_first_two_bytes),
176 F::from_bytes_le(&bytes),
177 )
178 })
179 .unzip();
180
181 let f_gr4_modulus = F::from(GR4_MODULUS);
182
183 let mut counts_shares: HashMap<F, u64> = HashMap::new();
185 expected_shares.iter().for_each(|x| {
186 assert!(x < &f_gr4_modulus);
188
189 *counts_shares.entry(*x).or_insert(0) += 1;
190 });
191
192 let mut multiplicities_shares = vec![F::ZERO; GR4_MODULUS as usize];
193 counts_shares.iter().for_each(|(k, v)| {
194 multiplicities_shares[k.to_u64s_le()[0] as usize] = F::from(*v);
195 });
196 let num_elements = num_copies * 4;
198 let num_zeros = num_elements.next_power_of_two() - num_elements;
199 multiplicities_shares[0] += F::from(num_zeros as u64);
200
201 let mut counts_slopes: HashMap<F, u64> = HashMap::new();
203 slopes.iter().for_each(|x| {
204 assert!(x < f_gr4_modulus);
206
207 *counts_slopes.entry(x).or_insert(0) += 1;
208 });
209
210 let mut multiplicities_slopes = vec![F::ZERO; GR4_MODULUS as usize];
211 counts_slopes.iter().for_each(|(k, v)| {
212 multiplicities_slopes[k.to_u64s_le()[0] as usize] = F::from(*v);
213 });
214 let num_elements = num_copies * 4;
216 let num_zeros = num_elements.next_power_of_two() - num_elements;
217 multiplicities_slopes[0] += F::from(num_zeros as u64);
218
219 let mut counts_quotients: HashMap<F, u64> = HashMap::new();
221 quotients.iter().for_each(|x| {
222 assert!(x < &F::from(1 << 19));
224
225 *counts_quotients.entry(*x).or_insert(0) += 1;
226 });
227
228 let mut multiplicities_quotients = vec![F::ZERO; (1 << 19) as usize];
229 counts_quotients.iter().for_each(|(k, v)| {
230 multiplicities_quotients[k.to_u64s_le()[0] as usize] = F::from(*v);
231 });
232 let num_elements = num_copies * 4;
234 assert_eq!(quotients.len(), num_elements);
235 let num_zeros = num_elements.next_power_of_two() - num_elements;
236 multiplicities_quotients[0] += F::from(num_zeros as u64);
237
238 quotients
239 .iter()
240 .zip(shares_before_modulo_gr4.iter())
241 .zip(expected_shares.iter())
242 .for_each(|((quotient, share_before_modulo_gr4), expected_share)| {
243 assert_eq!(
244 *quotient * F::from(GR4_MODULUS) + expected_share,
245 *share_before_modulo_gr4
246 );
247 });
248
249 for f in shares_before_modulo_gr4.iter_mut() {
251 f.zeroize();
252 }
253 for f in masked_iris_codes.iter_mut() {
255 f.zeroize();
256 }
257
258 let quotients = MultilinearExtension::new(quotients);
259 let shares_reduced_modulo_gr4_modulus = MultilinearExtension::new(expected_shares);
260 let multiplicities_shares = MultilinearExtension::new(multiplicities_shares);
261 let multiplicities_slopes = MultilinearExtension::new(multiplicities_slopes);
262 let multiplicities_quotients = MultilinearExtension::new(multiplicities_quotients);
263
264 MPCCircuitInputData::<F> {
265 iris_codes: iris_codes.clone(),
266 masks: masks.clone(),
267 slopes: slopes.clone(),
268 quotients,
269 shares_reduced_modulo_gr4_modulus,
270 multiplicities_shares,
271 multiplicities_slopes,
272 multiplicities_quotients,
273 }
275}
276
277pub fn generate_trivial_test_data<
279 F: Field,
280 const NUM_IRIS_4_CHUNKS: usize,
281 const PARTY_IDX: usize,
282>() -> (MPCCircuitConstData<F>, MPCCircuitInputData<F>) {
283 let num_copies = NUM_IRIS_4_CHUNKS;
284 let mut rng = rand::thread_rng();
285
286 let iris_codes = MultilinearExtension::new(
287 (0..4 * num_copies)
288 .map(|_| F::from(rng.gen_range(0..=1)))
289 .collect(),
290 );
291 let masks = MultilinearExtension::new(
292 (0..4 * num_copies)
293 .map(|_| F::from(rng.gen_range(0..=1)))
294 .collect(),
295 );
296 let slopes = MultilinearExtension::new(
297 (0..4 * num_copies)
298 .map(|_| F::from(rng.gen_range(0..=(GR4_MODULUS - 1))))
299 .collect(),
300 );
301
302 let mpc_aux_data = gen_mpc_common_aux_data::<F, NUM_IRIS_4_CHUNKS, PARTY_IDX>();
303
304 let mpc_input_data = gen_mpc_input_data::<F, NUM_IRIS_4_CHUNKS>(
305 &iris_codes,
306 &masks,
307 &slopes,
308 &mpc_aux_data.encoding_matrix,
309 &mpc_aux_data.evaluation_points,
310 );
311
312 assert_eq!(mpc_input_data.quotients.len(), num_copies * 4);
313 assert_eq!(
314 mpc_input_data.shares_reduced_modulo_gr4_modulus.len(),
315 num_copies * 4
316 );
317 assert_eq!(slopes.len(), num_copies * 4);
318 assert_eq!(mpc_aux_data.evaluation_points.len(), num_copies * 4);
319 assert_eq!(
320 mpc_input_data.multiplicities_shares.len(),
321 GR4_MODULUS as usize
322 );
323 assert_eq!(
324 mpc_input_data.multiplicities_slopes.len(),
325 GR4_MODULUS as usize
326 );
327 assert_eq!(mpc_input_data.multiplicities_quotients.len(), (1 << 19));
328 assert_eq!(mpc_aux_data.lookup_table_values.len(), GR4_MODULUS as usize);
329
330 (mpc_aux_data, mpc_input_data)
331}
332
333pub fn fetch_inversed_test_data<
336 F: Field,
337 const NUM_IRIS_4_CHUNKS: usize,
338 const PARTY_IDX: usize,
339>(
340 test_idx: usize,
341) -> (MPCCircuitConstData<F>, MPCCircuitInputData<F>) {
342 let num_copies = NUM_IRIS_4_CHUNKS;
343 if test_idx + NUM_IRIS_4_CHUNKS >= TEST_MASKED_IRIS_CODES.len() {
344 panic!("test_idx out of range");
345 }
346 let mut rng = rand::thread_rng();
347
348 let masked_iris_codes = MultilinearExtension::new(
349 (0..num_copies)
350 .flat_map(|batch_idx| {
351 TEST_MASKED_IRIS_CODES[batch_idx + test_idx]
352 .into_iter()
353 .map(F::from)
354 .collect::<Vec<F>>()
355 })
356 .collect_vec(),
357 );
358 let iris_codes = MultilinearExtension::new(
359 (0..num_copies * 4)
360 .map(|_| F::from(rng.gen_range(0..=1)))
361 .collect(),
362 );
363 assert_eq!(masked_iris_codes.len(), iris_codes.len());
364 let masks = MultilinearExtension::new(
365 masked_iris_codes
366 .iter()
367 .zip(iris_codes.iter())
368 .map(|(masked_iris_code, iris_code)| F::from(2).neg() * iris_code - masked_iris_code)
369 .collect(),
370 );
371 let slopes = MultilinearExtension::new(
372 (0..num_copies)
373 .flat_map(|batch_idx| {
374 TEST_RANDOMNESSES[batch_idx + test_idx]
375 .into_iter()
376 .map(F::from)
377 .collect::<Vec<F>>()
378 })
379 .collect_vec(),
380 );
381
382 let mpc_aux_data = gen_mpc_common_aux_data::<F, NUM_IRIS_4_CHUNKS, PARTY_IDX>();
383 let mpc_input_data = gen_mpc_input_data::<F, NUM_IRIS_4_CHUNKS>(
384 &iris_codes,
385 &masks,
386 &slopes,
387 &mpc_aux_data.encoding_matrix,
388 &mpc_aux_data.evaluation_points,
389 );
390
391 assert_eq!(mpc_input_data.quotients.len(), num_copies * 4);
392 assert_eq!(
393 mpc_input_data.shares_reduced_modulo_gr4_modulus.len(),
394 num_copies * 4
395 );
396 assert_eq!(slopes.len(), num_copies * 4);
397 assert_eq!(mpc_aux_data.evaluation_points.len(), num_copies * 4);
398 assert_eq!(
399 mpc_input_data.multiplicities_shares.len(),
400 GR4_MODULUS as usize
401 );
402 assert_eq!(
403 mpc_input_data.multiplicities_slopes.len(),
404 GR4_MODULUS as usize
405 );
406 assert_eq!(mpc_aux_data.lookup_table_values.len(), GR4_MODULUS as usize);
407
408 mpc_input_data
409 .shares_reduced_modulo_gr4_modulus
410 .iter()
411 .zip(
412 (0..num_copies)
413 .flat_map(|batch_idx| TEST_SHARES[PARTY_IDX][batch_idx + test_idx].into_iter())
414 .collect_vec()
415 .iter(),
416 )
417 .for_each(|(a, b)| {
418 assert_eq!(a, F::from(*b));
419 });
420
421 (mpc_aux_data, mpc_input_data)
422}