1#![allow(clippy::type_complexity)]
2
3use crate::components::digits::DigitComponents;
4use crate::layouter::builder::{Circuit, CircuitBuilder, LayerVisibility};
5use crate::zk_iriscode_ss::components::ZkIriscodeComponent;
6use crate::zk_iriscode_ss::data::IriscodeCircuitAuxData;
7use remainder::mle::evals::MultilinearExtension;
8use remainder::utils::arithmetic::log2_ceil;
9
10use itertools::Itertools;
11use shared_types::Field;
12
13use super::data::IriscodeCircuitInputData;
14
15use anyhow::Result;
16
// Input-layer labels. These, together with the shred names below, are the identifiers used when
// the circuit is built (`add_input_layer` / `add_input_shred`) and when data is attached
// (`Circuit::set_input`).

/// Label of the input layer holding the image that is rerouted into per-strip identity gates.
pub const V3_INPUT_IMAGE_LAYER: &str = "Input image (to reroute)";
/// Label of the input layer holding the `NUM_DIGITS` digit shreds and the digit multiplicities.
pub const V3_DIGITS_LAYER: &str = "Digit values and multiplicities";
/// Label of the input layer holding the sign bits.
pub const V3_SIGN_BITS_LAYER: &str = "Sign Bits";
/// Label of the auxiliary input layer; the circuit builder always adds it as
/// `LayerVisibility::Public`, regardless of the requested visibility.
pub const V3_AUXILIARY_LAYER: &str = "Auxiliary Data";

/// Shred name for the image to reroute (lives in [`V3_INPUT_IMAGE_LAYER`]).
pub const V3_INPUT_IMAGE_SHRED: &str = "Image to reroute";
/// Prefix of the digit shred names: digit `i` is named `"{V3_DIGITS_SHRED_TEMPLATE} {i}"`.
pub const V3_DIGITS_SHRED_TEMPLATE: &str = "Digits Input Shred";
/// Shred name for the digit multiplicities consumed by the digit range-check lookup
/// (lives in [`V3_DIGITS_LAYER`]).
pub const V3_DIGITS_MULTIPLICITIES_SHRED: &str = "Digits multiplicities";
/// Shred name for the value subtracted from the matmult result (auxiliary layer).
pub const V3_TO_SUB_MATMULT_SHRED: &str = "Input to subtract from MatMult";
/// Shred name for the right-hand multiplicand of the matmult (auxiliary layer).
pub const V3_RH_MATMULT_SHRED: &str = "RH Multiplicand of MatMult";
/// Shred name for the lookup table used to range check the digits (auxiliary layer).
pub const V3_LOOKUP_SHRED: &str = "Lookup table values for digit range check";
/// Shred name for the sign bits (lives in [`V3_SIGN_BITS_LAYER`]).
///
/// NOTE(review): this is the same string as `V3_SIGN_BITS_LAYER`; presumably layer labels and
/// shred names are resolved in separate namespaces — confirm before relying on it.
pub const V3_SIGN_BITS_SHRED: &str = "Sign Bits";
33
34pub fn build_iriscode_circuit_description<
36 F: Field,
37 const IM_STRIP_ROWS: usize,
38 const IM_STRIP_COLS: usize,
39 const IM_NUM_VARS: usize,
40 const MATMULT_ROWS_NUM_VARS: usize,
41 const MATMULT_COLS_NUM_VARS: usize,
42 const MATMULT_INTERNAL_DIM_NUM_VARS: usize,
43 const BASE: u64,
44 const NUM_DIGITS: usize,
45>(
46 layer_visibility: LayerVisibility,
47 image_strip_reroutings: Vec<Vec<(u32, u32)>>,
48 lh_matrix_reroutings: Vec<(u32, u32)>,
49) -> Result<Circuit<F>> {
50 let mut builder = CircuitBuilder::<F>::new();
51
52 assert!(BASE.is_power_of_two());
53 let log_base = log2_ceil(BASE) as usize;
54 let num_strips = image_strip_reroutings.len();
55 assert!(num_strips.is_power_of_two());
56 let log_num_strips = log2_ceil(num_strips) as usize;
57
58 let to_reroute_input_layer = builder.add_input_layer(V3_INPUT_IMAGE_LAYER, layer_visibility);
60 let to_reroute =
61 builder.add_input_shred(V3_INPUT_IMAGE_SHRED, IM_NUM_VARS, &to_reroute_input_layer);
62
63 let digits_input_layer = builder.add_input_layer(V3_DIGITS_LAYER, layer_visibility);
65 let digits_input_shreds: Vec<_> = (0..NUM_DIGITS)
66 .map(|i| {
67 builder.add_input_shred(
68 &format!("{V3_DIGITS_SHRED_TEMPLATE} {i}"),
69 log_num_strips + MATMULT_ROWS_NUM_VARS + MATMULT_COLS_NUM_VARS,
70 &digits_input_layer,
71 )
72 })
73 .collect();
74
75 let digit_multiplicities = builder.add_input_shred(
76 V3_DIGITS_MULTIPLICITIES_SHRED,
77 log_base,
78 &digits_input_layer,
79 );
80
81 let auxiliary_input_layer =
83 builder.add_input_layer(V3_AUXILIARY_LAYER, LayerVisibility::Public);
84
85 let to_sub_from_matmult = builder.add_input_shred(
86 V3_TO_SUB_MATMULT_SHRED,
87 log_num_strips + MATMULT_ROWS_NUM_VARS + MATMULT_COLS_NUM_VARS,
88 &auxiliary_input_layer,
89 );
90
91 let rh_matmult_multiplicand = builder.add_input_shred(
92 V3_RH_MATMULT_SHRED,
93 MATMULT_INTERNAL_DIM_NUM_VARS + MATMULT_COLS_NUM_VARS,
94 &auxiliary_input_layer,
95 );
96
97 let lookup_table_values =
98 builder.add_input_shred(V3_LOOKUP_SHRED, log_base, &auxiliary_input_layer);
99
100 let sign_bits_input_layer = builder.add_input_layer(V3_SIGN_BITS_LAYER, layer_visibility);
102 let sign_bits = builder.add_input_shred(
103 V3_SIGN_BITS_SHRED,
104 log_num_strips + MATMULT_ROWS_NUM_VARS + MATMULT_COLS_NUM_VARS,
105 &sign_bits_input_layer,
106 );
107
108 let rlc_challenges = (0..num_strips)
110 .map(|_| builder.add_fiat_shamir_challenge_node(1))
111 .collect_vec();
112 let rlc_challenges_generic = rlc_challenges
113 .clone()
114 .into_iter()
115 .map(|node| node.into())
116 .collect_vec();
117
118 let lookup_challenge = builder.add_fiat_shamir_challenge_node(1);
120
121 let image_strip_nodes = image_strip_reroutings
125 .into_iter()
126 .map(|reroutings| {
127 builder.add_identity_gate_node(
128 &to_reroute,
129 reroutings,
130 log2_ceil(IM_STRIP_ROWS * IM_STRIP_COLS) as usize,
131 None,
132 )
133 })
134 .collect_vec();
135
136 let image_rlc = ZkIriscodeComponent::sum_of_products(
138 &mut builder,
139 rlc_challenges_generic.iter().collect(),
140 image_strip_nodes.iter().collect(),
141 );
142
143 let rerouted_image = builder.add_identity_gate_node(
145 &image_rlc,
146 lh_matrix_reroutings,
147 MATMULT_ROWS_NUM_VARS + MATMULT_INTERNAL_DIM_NUM_VARS,
148 None,
149 );
150
151 let matmult = builder.add_matmult_node(
153 &rerouted_image,
154 (MATMULT_ROWS_NUM_VARS, MATMULT_INTERNAL_DIM_NUM_VARS),
155 &rh_matmult_multiplicand,
156 (MATMULT_INTERNAL_DIM_NUM_VARS, MATMULT_COLS_NUM_VARS),
157 );
158
159 let to_sub_from_matmult_splits = builder.add_split_node(&to_sub_from_matmult, log_num_strips);
161
162 let to_sub_from_matmult_rlc = ZkIriscodeComponent::sum_of_products(
163 &mut builder,
164 rlc_challenges_generic.iter().collect(),
165 to_sub_from_matmult_splits.iter().collect(),
166 );
167
168 let subtractor = builder.add_sector(matmult - to_sub_from_matmult_rlc);
170
171 let digits_split_nodes = digits_input_shreds
173 .iter()
174 .map(|shred| builder.add_split_node(shred, log_num_strips))
175 .collect_vec();
176 let digits_rlc = digits_split_nodes
177 .iter()
178 .map(|splits| {
179 let digit_rlc = ZkIriscodeComponent::sum_of_products(
180 &mut builder,
181 rlc_challenges_generic.iter().collect(),
182 splits.iter().collect(),
183 );
184 digit_rlc
185 })
186 .collect_vec();
187
188 let digits_concatenator = DigitComponents::digits_concatenator(
191 &mut builder,
192 &digits_input_shreds.iter().collect_vec(),
193 );
194
195 let lookup_table = builder.add_lookup_table(&lookup_table_values, &lookup_challenge);
197 let _lookup_constraint =
199 builder.add_lookup_constraint(&lookup_table, &digits_concatenator, &digit_multiplicities);
200 let unsigned_recomp = DigitComponents::unsigned_recomposition(
204 &mut builder,
205 &digits_rlc.iter().collect_vec(),
206 BASE,
207 );
208
209 let sign_bits_splits = builder.add_split_node(&sign_bits, log_num_strips);
211
212 let sign_bits_rlc = ZkIriscodeComponent::sum_of_products(
213 &mut builder,
214 rlc_challenges_generic.iter().collect(),
215 sign_bits_splits.iter().collect(),
216 );
217
218 let complementary_checker = DigitComponents::complementary_recomp_check(
220 &mut builder,
221 &subtractor,
222 &sign_bits_rlc,
223 &unsigned_recomp,
224 BASE,
225 NUM_DIGITS,
226 );
227 builder.set_output(&complementary_checker);
228
229 let bits_are_binary = DigitComponents::bits_are_binary(&mut builder, &sign_bits);
230 builder.set_output(&bits_are_binary);
231
232 builder.build_without_layer_combination()
234}
235
236pub fn iriscode_ss_attach_aux_data<F: Field, const BASE: u64>(
237 mut circuit: Circuit<F>,
238 iriscode_aux_data: IriscodeCircuitAuxData<F>,
239) -> Result<Circuit<F>> {
240 circuit.set_input(
241 V3_RH_MATMULT_SHRED,
242 iriscode_aux_data.rh_matmult_multiplicand,
243 );
244
245 circuit.set_input(
246 V3_TO_SUB_MATMULT_SHRED,
247 iriscode_aux_data.to_sub_from_matmult,
248 );
249
250 circuit.set_input(
251 V3_LOOKUP_SHRED,
252 MultilinearExtension::new((0..BASE).map(F::from).collect()),
253 );
254
255 Ok(circuit)
256}
257
258pub fn iriscode_ss_attach_input_data<F: Field, const BASE: u64>(
262 mut circuit: Circuit<F>,
263 iriscode_input_data: IriscodeCircuitInputData<F>,
264 iriscode_aux_data: IriscodeCircuitAuxData<F>,
265) -> Result<Circuit<F>> {
266 circuit.set_input(V3_INPUT_IMAGE_SHRED, iriscode_input_data.to_reroute);
267 circuit.set_input(
268 V3_RH_MATMULT_SHRED,
269 iriscode_aux_data.rh_matmult_multiplicand,
270 );
271
272 iriscode_input_data
273 .digits
274 .into_iter()
275 .enumerate()
276 .for_each(|(i, mle)| {
277 circuit.set_input(&format!("{V3_DIGITS_SHRED_TEMPLATE} {i}"), mle);
278 });
279
280 circuit.set_input(V3_SIGN_BITS_SHRED, iriscode_input_data.sign_bits);
281 circuit.set_input(
282 V3_TO_SUB_MATMULT_SHRED,
283 iriscode_aux_data.to_sub_from_matmult,
284 );
285 circuit.set_input(
286 V3_DIGITS_MULTIPLICITIES_SHRED,
287 iriscode_input_data.digit_multiplicities,
288 );
289 circuit.set_input(
290 V3_LOOKUP_SHRED,
291 MultilinearExtension::new((0..BASE).map(F::from).collect()),
292 );
293
294 Ok(circuit)
295}