shared_types/transcript/
ec_transcript.rs

1use std::fmt::Display;
2
3use itertools::Itertools;
4
5use crate::{transcript::utils::sha256_hash_chain_on_field_elems, HasByteRepresentation};
6
7use crate::curves::PrimeOrderCurve;
8
9use super::{ProverTranscript, Transcript, TranscriptSponge};
10
/// Extension of [`TranscriptSponge`] (over the base field of the curve `C`)
/// which can additionally absorb elliptic curve points of type `C`.
pub trait ECTranscriptSponge<C: PrimeOrderCurve>: TranscriptSponge<C::Base> {
    /// Absorb a single elliptic curve point `elem`.
    fn absorb_ec_point(&mut self, elem: C);

    /// Absorb a list of elliptic curve points sequentially.
    fn absorb_ec_points(&mut self, elements: &[C]);
}
18
19impl<C, Tr> ECTranscriptSponge<C> for Tr
20where
21    C: PrimeOrderCurve,
22    Tr: TranscriptSponge<C::Base>,
23{
24    fn absorb_ec_point(&mut self, elem: C) {
25        let (x, y) = elem.affine_coordinates().unwrap();
26        self.absorb(x);
27        self.absorb(y);
28    }
29
30    fn absorb_ec_points(&mut self, elements: &[C]) {
31        elements.iter().for_each(|elem| {
32            let (x, y) = elem.affine_coordinates().unwrap();
33            self.absorb(x);
34            self.absorb(y);
35        });
36    }
37}
38
/// The purpose of this trait is simply to hide (i.e. abstract away) the generic for the sponge
/// type from the prover and verifier code.
pub trait ECTranscriptTrait<C: PrimeOrderCurve>: Display {
    /// Appends a single elliptic curve point to the transcript under `label`.
    fn append_ec_point(&mut self, label: &str, elem: C);

    /// Appends a list of elliptic curve points to the transcript under `label`.
    fn append_ec_points(&mut self, label: &str, elements: &[C]);

    /// Appends a single scalar field element to the transcript under `label`.
    fn append_scalar_field_elem(&mut self, label: &str, elem: C::Scalar);

    /// Appends a list of scalar field elements to the transcript under `label`.
    fn append_scalar_field_elems(&mut self, label: &str, elements: &[C::Scalar]);

    /// This function absorbs elliptic curve points as individual base field
    /// elements, and additionally absorbs the hash chain digest of the
    /// base field elements.
    fn append_input_ec_points(&mut self, label: &str, elements: Vec<C>);

    /// This function absorbs scalar field elements into the transcript sponge,
    /// and additionally absorbs the hash chain digest of these elements.
    fn append_input_scalar_field_elems(&mut self, label: &str, elements: &[C::Scalar]);

    /// This function absorbs base field elements into the transcript sponge.
    fn append_base_field_elems(&mut self, label: &str, elements: &[C::Base]);

    /// Squeezes a single challenge, returned as a scalar field element.
    fn get_scalar_field_challenge(&mut self, label: &str) -> C::Scalar;

    /// Squeezes `num_elements` challenges, returned as scalar field elements.
    fn get_scalar_field_challenges(&mut self, label: &str, num_elements: usize) -> Vec<C::Scalar>;

    /// Squeezes a single challenge, returned as an elliptic curve point.
    ///
    /// NOTE: for the `ECTranscript` implementation in this file, calling this
    /// function `n` times does NOT yield the same points as a single call to
    /// `get_ec_challenges` with `num_elements = n`.
    fn get_ec_challenge(&mut self, label: &str) -> C;

    /// Squeezes `num_elements` challenges, returned as elliptic curve points.
    ///
    /// NOTE: for the `ECTranscript` implementation in this file, this is NOT
    /// equivalent to calling `get_ec_challenge` `num_elements` times.
    fn get_ec_challenges(&mut self, label: &str, num_elements: usize) -> Vec<C>;
}
70
/// A transcript that operates over the base field of a prime-order curve, while also allowing for
/// the absorption and sampling of scalar field elements (and of course, EC points).
///
/// `C` is the curve; `T` is the type of the sponge used to absorb and squeeze
/// base field elements.
pub struct ECTranscript<C: PrimeOrderCurve, T> {
    /// The sponge that this transcript is using to append/squeeze elements.
    sponge: T,

    /// A mutable transcript which keeps a record of all the append/squeeze
    /// operations, expressed over the base field of `C`.
    transcript: Transcript<C::Base>,

    /// Whether to print debug information (to stdout) on every
    /// append/squeeze operation.
    debug: bool,
}
84
85impl<C: PrimeOrderCurve, Tr: ECTranscriptSponge<C> + Default> ECTranscript<C, Tr> {
86    /// Destructively extract the transcript produced by this writer.
87    /// This should be the last operation performed on a `TranscriptWriter`.
88    pub fn get_transcript(self) -> Transcript<C::Base> {
89        self.transcript
90    }
91
92    /// Creates an empty sponge.
93    /// `label` is an identifier used for debugging purposes.
94    pub fn new(label: &str) -> Self {
95        Self {
96            sponge: Tr::default(),
97            transcript: Transcript::new(label),
98            debug: false,
99        }
100    }
101
102    /// Creates an empty sponge in debug mode (i.e. with debug information printed).
103    /// `label` is an identifier used for debugging purposes.
104    pub fn new_with_debug(label: &str) -> Self {
105        Self {
106            sponge: Tr::default(),
107            transcript: Transcript::new(label),
108            debug: true,
109        }
110    }
111}
112
113impl<C: PrimeOrderCurve, Tr: ECTranscriptSponge<C> + Default> ECTranscriptTrait<C>
114    for ECTranscript<C, Tr>
115{
116    fn append_ec_point(&mut self, label: &str, elem: C) {
117        let (x_coord, y_coord) = elem.affine_coordinates().unwrap();
118        self.append_elements(label, &[x_coord, y_coord]);
119    }
120
121    fn append_ec_points(&mut self, label: &str, elements: &[C]) {
122        elements.iter().for_each(|elem| {
123            let (x_coord, y_coord) = elem.affine_coordinates().unwrap();
124            self.append_elements(label, &[x_coord, y_coord]);
125        });
126    }
127
128    fn append_scalar_field_elem(&mut self, label: &str, elem: C::Scalar) {
129        let base_elem = C::Base::from_bytes_le(&elem.to_bytes_le());
130        self.append(label, base_elem);
131    }
132
133    fn append_scalar_field_elems(&mut self, label: &str, elements: &[C::Scalar]) {
134        elements.iter().for_each(|elem| {
135            let base_elem = C::Base::from_bytes_le(&elem.to_bytes_le());
136            self.append(label, base_elem);
137        });
138    }
139
140    /// Literally takes the byte representation of the base field element and
141    /// dumps it (TODO: in an unsafe manner! Make this return an error rather
142    /// than just panicking) into a scalar field element's representation.
143    fn get_scalar_field_challenge(&mut self, label: &str) -> <C as PrimeOrderCurve>::Scalar {
144        let base_field_challenge = self.get_challenge(label);
145        C::Scalar::from_bytes_le(&base_field_challenge.to_bytes_le())
146    }
147
148    fn get_scalar_field_challenges(
149        &mut self,
150        label: &str,
151        num_elements: usize,
152    ) -> Vec<<C as PrimeOrderCurve>::Scalar> {
153        let base_field_challenges = self.get_challenges(label, num_elements);
154        base_field_challenges
155            .iter()
156            .map(|base_field_challenge| {
157                C::Scalar::from_bytes_le(&base_field_challenge.to_bytes_le())
158            })
159            .collect()
160    }
161
162    /// Generates two base field elements, and uses only the parity of the second
163    /// to determine the actual `y`-coordinate to be used.
164    ///
165    /// WARNING/TODO(ryancao): USING THIS FUNCTION `num_elements` TIMES WILL
166    /// NOT PRODUCE THE SAME EC CHALLENGES AS CALLING [Self::get_ec_challenges]
167    /// WITH `num_elements` AS A PARAMETER!!!
168    ///
169    /// IN PARTICULAR, THIS FUNCTION
170    /// GENERATES (x, y) ELEMENTS IN INDIVIDUAL PAIRS, WHILE THE
171    /// [Self::get_ec_challenges] FUNCTION GENERATES (x, y) ELEMENTS BY FIRST
172    /// GENERATING ALL x-coordinates AND THEN GENERATING ALL ELEMENTS DETERMINING
173    /// THE PARITY OF THE CORRESPONDING y-coordinates.
174    fn get_ec_challenge(&mut self, label: &str) -> C {
175        let x_coord_label = label.to_string() + ": x-coord";
176        let x_coord = self.get_challenge(&x_coord_label);
177
178        let y_coord_sign_elem_label = label.to_string() + ": y-coord sign elem";
179        let y_coord_sign_elem = self.get_challenge(&y_coord_sign_elem_label);
180        let y_coord_sign = y_coord_sign_elem.to_bytes_le()[0] & 1;
181
182        C::from_x_and_sign_y(x_coord, y_coord_sign)
183    }
184
185    /// Generates two base field elements for each element requested, by FIRST
186    /// generating ALL of the x-coords and AFTERWARDS generating ALL of the
187    /// base field elements whose parity determines the sign of the corresponding
188    /// y-coord.
189    ///
190    /// WARNING/TODO(ryancao): SEE WARNING FOR [Self::get_ec_challenge]!!!
191    fn get_ec_challenges(&mut self, label: &str, num_elements: usize) -> Vec<C> {
192        let x_coord_label = label.to_string() + ": x-coords";
193        let y_coord_sign_elem_label = label.to_string() + ": y-coord sign elems";
194
195        let x_coords = self.get_challenges(&x_coord_label, num_elements);
196        let y_coord_sign_elems = self.get_challenges(&y_coord_sign_elem_label, num_elements);
197
198        let y_coord_signs = y_coord_sign_elems
199            .iter()
200            .map(|y_coord_sign_elem| y_coord_sign_elem.to_bytes_le()[0] & 1);
201
202        x_coords
203            .into_iter()
204            .zip(y_coord_signs)
205            .map(|(x_coord, y_coord_sign)| C::from_x_and_sign_y(x_coord, y_coord_sign))
206            .collect()
207    }
208
209    fn append_input_ec_points(&mut self, label: &str, elements: Vec<C>) {
210        // We compute the list of all x-coordinates interwoven with all
211        // y-coordinates, i.e. [x_1, y_1, x_2, y_2, ...]
212        let elements_interwoven_x_y_coords = elements
213            .into_iter()
214            .map(|ec_element| ec_element.affine_coordinates().unwrap())
215            .flat_map(|(x, y)| vec![x, y])
216            .collect_vec();
217        // We then compute the hash chain digest of the list.
218        let hash_chain_digest = sha256_hash_chain_on_field_elems(&elements_interwoven_x_y_coords);
219        // We first absorb the interwoven x/y coordinates, then the hash chain (both as native base field elements).
220        self.append_base_field_elems(label, &elements_interwoven_x_y_coords);
221        self.append_base_field_elems(label, &hash_chain_digest);
222    }
223
224    fn append_input_scalar_field_elems(
225        &mut self,
226        label: &str,
227        elements: &[<C as PrimeOrderCurve>::Scalar],
228    ) {
229        // First, compute the hash chain digest of the elements.
230        let hash_chain_digest = sha256_hash_chain_on_field_elems(elements);
231        // Next, we simply absorb the elements and then the hash chain digest of the elements.
232        self.append_scalar_field_elems(label, elements);
233        self.append_scalar_field_elems(label, &hash_chain_digest);
234    }
235
236    fn append_base_field_elems(&mut self, label: &str, elements: &[<C as PrimeOrderCurve>::Base]) {
237        // This is a simple wrapper around the `ProverTranscript<C::Base>` trait.
238        self.append_elements(label, elements);
239    }
240}
241
242impl<C: PrimeOrderCurve, Sp: ECTranscriptSponge<C>> std::fmt::Display for ECTranscript<C, Sp> {
243    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        self.transcript.fmt(f)
245    }
246}
247
248impl<C: PrimeOrderCurve, Sp: TranscriptSponge<C::Base>> ProverTranscript<C::Base>
249    for ECTranscript<C, Sp>
250{
251    fn append(&mut self, label: &str, elem: C::Base) {
252        if self.debug {
253            println!("Appending element (\"{label}\"): {elem:?}");
254        }
255        self.sponge.absorb(elem);
256        self.transcript.append_elements(label, &[elem]);
257    }
258
259    fn append_elements(&mut self, label: &str, elements: &[C::Base]) {
260        if !elements.is_empty() {
261            if self.debug {
262                println!(
263                    "Appending {} elements (\"{}\"): [{:?}, .., ]",
264                    elements.len(),
265                    label,
266                    elements[0]
267                );
268            }
269            self.sponge.absorb_elements(elements);
270            self.transcript.append_elements(label, elements);
271        }
272    }
273
274    fn get_challenge(&mut self, label: &str) -> C::Base {
275        let challenge = self.sponge.squeeze();
276        self.transcript.squeeze_elements(label, 1);
277        if self.debug {
278            println!("Squeezing challenge (\"{label}\"): {challenge:?}");
279        }
280        challenge
281    }
282
283    fn get_challenges(&mut self, label: &str, num_elements: usize) -> Vec<C::Base> {
284        if num_elements == 0 {
285            vec![]
286        } else {
287            let challenges = self.sponge.squeeze_elements(num_elements);
288            self.transcript.squeeze_elements(label, num_elements);
289            if self.debug {
290                println!(
291                    "Squeezing {} challenges (\"{}\"): [{:?}, .., ]",
292                    num_elements, label, challenges[0]
293                );
294            }
295            challenges
296        }
297    }
298
299    fn append_input_elements(&mut self, label: &str, elements: &[C::Base]) {
300        let hash_chain_digest = sha256_hash_chain_on_field_elems(elements);
301        self.transcript
302            .append_input_elements(label, elements, &hash_chain_digest);
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::{ECTranscript, ECTranscriptTrait};
    use crate::{
        curves::PrimeOrderCurve,
        transcript::{poseidon_sponge::PoseidonSponge, ProverTranscript},
        Base, Bn256Point, Scalar,
    };
    use ark_std::test_rng;
    use itertools::Itertools;
    use rand::Rng;

    /// The concrete transcript type exercised by every test in this module.
    type TestTranscript = ECTranscript<Bn256Point, PoseidonSponge<Base>>;

    /// Number of random inputs appended to the transcripts in each test.
    const NUM_INPUTS: usize = 1000;

    /// Checks that appending scalar field elements via
    /// `append_input_scalar_field_elems` leads to a different challenge than
    /// appending the very same elements via `append_scalar_field_elems`.
    #[test]
    fn test_ec_scalar_input_hash_soundness() {
        // The two transcripts under comparison.
        let mut plain_transcript = TestTranscript::new("test ec_transcript_1");
        let mut input_transcript = TestTranscript::new("test ec_transcript_2");

        // Random scalars shared by both transcripts.
        let mut rng = test_rng();
        let scalars = (0..NUM_INPUTS)
            .map(|_| Scalar::from(rng.gen::<u64>()))
            .collect_vec();

        plain_transcript.append_scalar_field_elems("test append scalar field elems", &scalars);
        input_transcript
            .append_input_scalar_field_elems("test append input scalar field elems", &scalars);

        // The resulting challenges must differ.
        let plain_challenge = plain_transcript.get_challenge("get challenge 1");
        let input_challenge = input_transcript.get_challenge("get challenge 2");
        assert_ne!(plain_challenge, input_challenge);
    }

    /// Checks that appending EC points via `append_input_ec_points` leads to
    /// a different challenge than appending the very same points via
    /// `append_ec_points`.
    #[test]
    fn test_ec_point_input_hash_soundness() {
        // The two transcripts under comparison.
        let mut plain_transcript = TestTranscript::new("test ec_transcript_1");
        let mut input_transcript = TestTranscript::new("test ec_transcript_2");

        // Random curve points shared by both transcripts.
        let mut rng = test_rng();
        let points = (0..NUM_INPUTS)
            .map(|_| Bn256Point::random(&mut rng))
            .collect_vec();

        plain_transcript.append_ec_points("test append ec elems", &points);
        input_transcript.append_input_ec_points("test append input ec elems", points);

        // The resulting challenges must differ.
        let plain_challenge = plain_transcript.get_challenge("get challenge 1");
        let input_challenge = input_transcript.get_challenge("get challenge 2");
        assert_ne!(plain_challenge, input_challenge);
    }
}