1use crate::digits::{complementary_decomposition, digits_to_field, to_slice_of_mles};
2use itertools::Itertools;
3use ndarray::{Array, Array2};
4use remainder::mle::evals::MultilinearExtension;
5use remainder::utils::arithmetic::i64_to_field;
6use remainder::utils::mle::pad_with;
7use serde::{Deserialize, Serialize};
8use shared_types::Field;
9
10#[derive(Debug, Clone)]
12pub struct IriscodeCircuitInputData<F: Field> {
13 pub to_reroute: MultilinearExtension<F>,
16 pub digits: Vec<MultilinearExtension<F>>,
19 pub sign_bits: MultilinearExtension<F>,
24 pub digit_multiplicities: MultilinearExtension<F>,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32#[serde(bound = "F: Field")]
33pub struct IriscodeCircuitAuxData<F: Field> {
34 pub rh_matmult_multiplicand: MultilinearExtension<F>,
37
38 pub to_sub_from_matmult: MultilinearExtension<F>,
41}
42
/// Converts 2-D wirings into 1-D reroutings.
///
/// Each wiring is a tuple `(src_row, src_col, dest_row, dest_col)`. The source
/// coordinates are flattened in row-major order using `src_arr_num_cols`, and the
/// destination coordinates using `dest_arr_num_cols`
/// (`flat_idx = row * num_cols + col`).
///
/// Returns one `(dest_idx, src_idx)` pair per wiring, in input order. Note that the
/// destination index comes first in the output even though the source comes first
/// in the input.
///
/// # Panics
///
/// Panics if a flattened index overflows `usize` or does not fit in a `u32`.
/// Previously such an index was silently truncated by an `as u32` cast.
pub fn wirings_to_reroutings(
    wirings: &[(u16, u16, u16, u16)],
    src_arr_num_cols: usize,
    dest_arr_num_cols: usize,
) -> Vec<(u32, u32)> {
    // Function-scope import so the checked conversion resolves on every edition.
    use std::convert::TryFrom;

    // Row-major flattening with every step checked, so that an out-of-range index
    // is reported instead of wrapping around.
    let flatten = |row: u16, col: u16, num_cols: usize| -> u32 {
        usize::from(row)
            .checked_mul(num_cols)
            .and_then(|offset| offset.checked_add(usize::from(col)))
            .and_then(|idx| u32::try_from(idx).ok())
            .expect("flattened wiring index does not fit in u32")
    };

    wirings
        .iter()
        .map(|&(src_row, src_col, dest_row, dest_col)| {
            let src_idx = flatten(src_row, src_col, src_arr_num_cols);
            let dest_idx = flatten(dest_row, dest_col, dest_arr_num_cols);
            (dest_idx, src_idx)
        })
        .collect()
}
68
69pub fn build_iriscode_circuit_auxiliary_data<
70 F: Field,
71 const MATMULT_COLS_NUM_VARS: usize,
72 const MATMULT_INTERNAL_DIM_NUM_VARS: usize,
73 const NUM_STRIPS: usize,
74 const MAT_CHUNK_SIZE: usize,
75>(
76 rh_multiplicand: &[i32],
77 thresholds_matrix: &[i64],
78) -> IriscodeCircuitAuxData<F> {
79 let rh_multiplicand = Array2::from_shape_vec(
81 (
82 1 << MATMULT_INTERNAL_DIM_NUM_VARS,
83 1 << MATMULT_COLS_NUM_VARS,
84 ),
85 rh_multiplicand.iter().map(|&x| x as i64).collect_vec(),
86 )
87 .unwrap();
88
89 let rh_matmult_multiplicand: Vec<F> =
91 rh_multiplicand.into_iter().map(i64_to_field).collect_vec();
92
93 let thresholds_matrix = Array2::from_shape_vec(
95 (NUM_STRIPS * MAT_CHUNK_SIZE, 1 << MATMULT_COLS_NUM_VARS),
96 thresholds_matrix.to_vec(),
97 )
98 .unwrap();
99
100 let thresholds_matrix: Vec<F> = pad_with(
102 F::ZERO,
103 &thresholds_matrix
104 .into_iter()
105 .map(i64_to_field)
106 .collect_vec(),
107 );
108
109 IriscodeCircuitAuxData {
110 rh_matmult_multiplicand: MultilinearExtension::new(rh_matmult_multiplicand),
111 to_sub_from_matmult: MultilinearExtension::new(thresholds_matrix),
112 }
113}
114
115pub fn build_iriscode_circuit_data<
118 F: Field,
119 const IM_STRIP_ROWS: usize,
120 const IM_STRIP_COLS: usize,
121 const MATMULT_ROWS_NUM_VARS: usize,
122 const MATMULT_COLS_NUM_VARS: usize,
123 const MATMULT_INTERNAL_DIM_NUM_VARS: usize,
124 const BASE: u64,
125 const NUM_DIGITS: usize,
126>(
127 image: Array2<u8>,
128 rh_multiplicand: &[i32],
129 thresholds_matrix: &[i64],
130 image_strip_wirings: Vec<Vec<(u16, u16, u16, u16)>>,
131 lh_matrix_wirings: &[(u16, u16, u16, u16)],
132) -> IriscodeCircuitInputData<F> {
133 assert!(BASE.is_power_of_two());
134 assert!(NUM_DIGITS.is_power_of_two());
135 let num_strips = image_strip_wirings.len();
136
137 let mat_chunk_size = 1 << MATMULT_ROWS_NUM_VARS;
139 let mut rerouted_matrix: Array2<i64> = Array::zeros((
140 num_strips * mat_chunk_size,
141 (1 << MATMULT_INTERNAL_DIM_NUM_VARS),
142 ));
143 image_strip_wirings
144 .iter()
145 .enumerate()
146 .for_each(|(strip_idx, wirings)| {
147 let mut image_strip: Array2<i64> = Array::zeros((IM_STRIP_ROWS, IM_STRIP_COLS));
149 wirings.iter().for_each(|row| {
150 let (im_row, im_col, im_strip_row, im_strip_col) = (
151 row.0 as usize,
152 row.1 as usize,
153 row.2 as usize,
154 row.3 as usize,
155 );
156 image_strip[[im_strip_row, im_strip_col]] = image[[im_row, im_col]] as i64;
157 });
158 lh_matrix_wirings.iter().for_each(|row| {
160 let (im_strip_row, im_strip_col, mat_row, mat_col) = (
161 row.0 as usize,
162 row.1 as usize,
163 row.2 as usize,
164 row.3 as usize,
165 );
166 rerouted_matrix[[strip_idx * mat_chunk_size + mat_row, mat_col]] =
167 image_strip[[im_strip_row, im_strip_col]];
168 });
169 });
170
171 let rh_multiplicand = Array2::from_shape_vec(
173 (
174 1 << MATMULT_INTERNAL_DIM_NUM_VARS,
175 1 << MATMULT_COLS_NUM_VARS,
176 ),
177 rh_multiplicand.iter().map(|&x| x as i64).collect_vec(),
178 )
179 .unwrap();
180
181 let thresholds_matrix = Array2::from_shape_vec(
183 (num_strips * mat_chunk_size, 1 << MATMULT_COLS_NUM_VARS),
184 thresholds_matrix.to_vec(),
185 )
186 .unwrap();
187
188 let responses = rerouted_matrix.dot(&rh_multiplicand);
190
191 let thres_resp = pad_with(
196 0,
197 &(responses - &thresholds_matrix).into_iter().collect_vec(),
198 );
199
200 let (digits, code): (Vec<_>, Vec<_>) = thres_resp
203 .into_iter()
204 .map(|value| complementary_decomposition::<BASE, NUM_DIGITS>(value).unwrap())
205 .unzip();
206
207 let mut digit_multiplicities: Vec<usize> = vec![0; BASE as usize];
209 digits.iter().for_each(|decomp| {
210 decomp.iter().for_each(|&digit| {
211 digit_multiplicities[digit as usize] += 1;
212 })
213 });
214
215 let image_matrix_mle: Vec<F> = pad_with(0, &image.into_iter().collect_vec())
218 .into_iter()
219 .map(|v| F::from(v as u64))
220 .collect_vec();
221
222 let code: Vec<F> = code
224 .into_iter()
225 .map(|elem| F::from(elem as u64))
226 .collect_vec();
227
228 let digit_multiplicities = digit_multiplicities
230 .into_iter()
231 .map(|x| F::from(x as u64))
232 .collect_vec();
233 let digits = to_slice_of_mles(digits.iter().map(digits_to_field).collect_vec()).to_vec();
234
235 IriscodeCircuitInputData {
236 to_reroute: MultilinearExtension::new(image_matrix_mle),
237 digits,
238 sign_bits: MultilinearExtension::new(code),
239 digit_multiplicities: MultilinearExtension::new(digit_multiplicities),
240 }
241}