shared_types/
curves.rs

1use crate::{ff_field, Zeroizable};
2use crate::{
3    halo2curves::{bn256::G1 as Bn256, CurveExt},
4    Field, HasByteRepresentation,
5};
6use ark_std::rand::{self, RngCore};
7use halo2curves::bn256::{Fq, Fr};
8use itertools::Itertools;
9use num::traits::ToBytes;
10use num::{Unsigned, Zero};
11use rand::CryptoRng;
12use serde::{Deserialize, Serialize};
13use sha3::{
14    digest::{core_api::XofReaderCoreWrapper, XofReader},
15    Shake256ReaderCore,
16};
/// Traits and implementations for elliptic curves of prime order.
///
/// Justification for creating our own elliptic curve trait:
/// + The trait in the halo2curves library that is closest to what is wanted is
///   `CurveExt`. However, the field trait it uses for the base and scalar
///   fields, viz. `WithSmallOrderMulGroup<3>`, is not appropriate, as it
///   restricts to finite fields for which p - 1 is divisible by 3. This is an
///   arbitrary restriction from our POV (though it is satisfied by Bn254).
///   (Further, we found the halo2curves traits very difficult to parse.)
/// + The `AffineCurve` trait from `ark-ec` is precisely as specific as required
///   and is beautifully written, but we'd need to implement the arkworks field
///   traits for the fields we use from halo2.
29use std::{
30    fmt,
31    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
32};
33use subtle::{Choice, ConditionallySelectable};
34// Implementation note: do not confuse the following with the very similarly
35// named shared_types::halo2curves::Group
36use crate::halo2curves::group::Group;
37
38#[cfg(test)]
39/// Tests for prime order curve.
40mod tests;
41
42/// Minimal interface for an elliptic curve of prime order.
43pub trait PrimeOrderCurve:
44    Copy
45    + Clone
46    + Sized
47    + Send
48    + Sync
49    + fmt::Debug
50    + Eq
51    + 'static
52    + Neg<Output = Self>
53    + Mul<Self::Scalar, Output = Self>
54    + Add<Self, Output = Self>
55    + Sub<Self, Output = Self>
56    + AddAssign<Self>
57    + SubAssign<Self>
58    + MulAssign<Self::Scalar>
59    + Serialize
60    + for<'de> Deserialize<'de>
61    + Zeroizable
62{
63    /// The scalar field of the curve.
64    type Scalar: Field;
65
66    /// The base field of the curve.
67    type Base: Field;
68
69    /// The byte sizes for the serialized representations.
70    const UNCOMPRESSED_CURVE_POINT_BYTEWIDTH: usize;
71    const COMPRESSED_CURVE_POINT_BYTEWIDTH: usize;
72    const SCALAR_ELEM_BYTEWIDTH: usize;
73
74    /// Return the additive identity of the curve.
75    fn zero() -> Self;
76
77    /// Return the chosen generator of the curve.
78    fn generator() -> Self;
79
80    /// Returns an element chosen uniformly at random.
81    fn random(rng: impl RngCore) -> Self;
82
83    /// Return the point doubled.
84    fn double(&self) -> Self;
85
86    /// Return the projective coordinates of the point.
87    fn projective_coordinates(&self) -> (Self::Base, Self::Base, Self::Base);
88
89    /// Return the affine coordinates of the point, if it is not at the identity
90    /// (in which case, return None).
91    fn affine_coordinates(&self) -> Option<(Self::Base, Self::Base)>;
92
93    /// Returns an uncompressed byte representation of a curve element.
94    fn to_bytes_uncompressed(&self) -> Vec<u8>;
95
96    /// Returns a compressed byte representation of a curve element.
97    /// TODO!(ryancao): Should we also have a self type representing the
98    /// bitwidth?
99    fn to_bytes_compressed(&self) -> Vec<u8>;
100
101    /// Returns the unique curve element represented by the uncompressed
102    /// bytestring.
103    fn from_bytes_uncompressed(bytes: &[u8]) -> Self;
104
105    /// Returns the unique curve element represented by the compressed
106    /// bytestring.
107    fn from_bytes_compressed(bytes: &[u8]) -> Self;
108
109    /// Returns the group element from x and y coordinates.
110    fn from_xy(x: Self::Base, y: Self::Base) -> Self;
111
112    /// Returns the group element from x coordinate + parity of y.
113    fn from_x_and_sign_y(x: Self::Base, y_sign: u8) -> Self;
114
115    /// An optimized version of scalar multiplication when the scalar element
116    /// fits within 128 bits.
117    fn scalar_mult_unsigned_integer<T: Unsigned + Zero + ToBytes>(&self, scalar: &T) -> Self;
118}
119
120/// TODO(ryancao): Test these implementations!
121impl HasByteRepresentation for Fr {
122    const REPR_NUM_BYTES: usize = 32;
123
124    fn from_bytes_le(bytes: &[u8]) -> Self {
125        if bytes.len() > Self::REPR_NUM_BYTES {
126            panic!("Error: Attempted to convert from greater than 32-length byte vector into Fr")
127        }
128        // Pad with 0s at the most significant bits if less than 32 bytes.
129        let bytes_len_32_slice: [u8; 32] = if bytes.len() < Self::REPR_NUM_BYTES {
130            let padding = vec![0_u8; Self::REPR_NUM_BYTES - bytes.len()];
131            let bytes_owned = bytes.to_owned();
132            bytes_owned
133                .into_iter()
134                .chain(padding)
135                .collect_vec()
136                .try_into()
137                .unwrap()
138        } else {
139            bytes.try_into().unwrap()
140        };
141        Fr::from_bytes(&bytes_len_32_slice).unwrap()
142    }
143
144    fn to_bytes_le(&self) -> Vec<u8> {
145        Fr::to_bytes(self).to_vec()
146    }
147
148    fn to_u64s_le(&self) -> Vec<u64> {
149        let bytes = self.to_bytes_le();
150
151        let fold_bytes = |acc, x: &u8| (acc << 8) + (*x as u64);
152
153        vec![
154            bytes[0..8].iter().rev().fold(0, fold_bytes),
155            bytes[8..16].iter().rev().fold(0, fold_bytes),
156            bytes[16..24].iter().rev().fold(0, fold_bytes),
157            bytes[24..32].iter().rev().fold(0, fold_bytes),
158        ]
159    }
160
161    fn from_u64s_le(words: Vec<u64>) -> Self
162    where
163        Self: Sized,
164    {
165        let mask_8bit = (1_u64 << 8) - 1;
166
167        Self::from_bytes_le(&[
168            (words[0] & mask_8bit) as u8,
169            ((words[0] & (mask_8bit << 8)) >> 8) as u8,
170            ((words[0] & (mask_8bit << 16)) >> 16) as u8,
171            ((words[0] & (mask_8bit << 24)) >> 24) as u8,
172            ((words[0] & (mask_8bit << 32)) >> 32) as u8,
173            ((words[0] & (mask_8bit << 40)) >> 40) as u8,
174            ((words[0] & (mask_8bit << 48)) >> 48) as u8,
175            ((words[0] & (mask_8bit << 56)) >> 56) as u8,
176            (words[1] & mask_8bit) as u8,
177            ((words[1] & (mask_8bit << 8)) >> 8) as u8,
178            ((words[1] & (mask_8bit << 16)) >> 16) as u8,
179            ((words[1] & (mask_8bit << 24)) >> 24) as u8,
180            ((words[1] & (mask_8bit << 32)) >> 32) as u8,
181            ((words[1] & (mask_8bit << 40)) >> 40) as u8,
182            ((words[1] & (mask_8bit << 48)) >> 48) as u8,
183            ((words[1] & (mask_8bit << 56)) >> 56) as u8,
184            (words[2] & mask_8bit) as u8,
185            ((words[2] & (mask_8bit << 8)) >> 8) as u8,
186            ((words[2] & (mask_8bit << 16)) >> 16) as u8,
187            ((words[2] & (mask_8bit << 24)) >> 24) as u8,
188            ((words[2] & (mask_8bit << 32)) >> 32) as u8,
189            ((words[2] & (mask_8bit << 40)) >> 40) as u8,
190            ((words[2] & (mask_8bit << 48)) >> 48) as u8,
191            ((words[2] & (mask_8bit << 56)) >> 56) as u8,
192            (words[3] & mask_8bit) as u8,
193            ((words[3] & (mask_8bit << 8)) >> 8) as u8,
194            ((words[3] & (mask_8bit << 16)) >> 16) as u8,
195            ((words[3] & (mask_8bit << 24)) >> 24) as u8,
196            ((words[3] & (mask_8bit << 32)) >> 32) as u8,
197            ((words[3] & (mask_8bit << 40)) >> 40) as u8,
198            ((words[3] & (mask_8bit << 48)) >> 48) as u8,
199            ((words[3] & (mask_8bit << 56)) >> 56) as u8,
200        ])
201    }
202
203    fn vec_from_bytes_le(bytes: &[u8]) -> Vec<Self>
204    where
205        Self: Sized,
206    {
207        bytes
208            .chunks(Self::REPR_NUM_BYTES)
209            .map(Self::from_bytes_le)
210            .collect()
211    }
212}
213
214impl Zeroizable for Fr {
215    fn zeroize(&mut self) {
216        *self = Fr::ZERO;
217    }
218}
219
220impl HasByteRepresentation for Fq {
221    const REPR_NUM_BYTES: usize = 32;
222
223    fn from_bytes_le(bytes: &[u8]) -> Self {
224        if bytes.len() != Self::REPR_NUM_BYTES {
225            panic!("Error: Attempted to convert from non-32-length byte vector into Fr")
226        }
227        let bytes_len_32_slice: [u8; 32] = bytes.try_into().unwrap();
228        Fq::from_bytes(&bytes_len_32_slice).unwrap()
229    }
230
231    fn to_bytes_le(&self) -> Vec<u8> {
232        Fq::to_bytes(self).to_vec()
233    }
234
235    fn to_u64s_le(&self) -> Vec<u64> {
236        let bytes = self.to_bytes_le();
237
238        let fold_bytes = |acc, x: &u8| (acc << 8) + (*x as u64);
239
240        vec![
241            bytes[0..8].iter().rev().fold(0, fold_bytes),
242            bytes[8..16].iter().rev().fold(0, fold_bytes),
243            bytes[16..24].iter().rev().fold(0, fold_bytes),
244            bytes[24..32].iter().rev().fold(0, fold_bytes),
245        ]
246    }
247
248    fn from_u64s_le(words: Vec<u64>) -> Self
249    where
250        Self: Sized,
251    {
252        let mask_8bit = (1_u64 << 8) - 1;
253
254        Self::from_bytes_le(&[
255            (words[0] & mask_8bit) as u8,
256            ((words[0] & (mask_8bit << 8)) >> 8) as u8,
257            ((words[0] & (mask_8bit << 16)) >> 16) as u8,
258            ((words[0] & (mask_8bit << 24)) >> 24) as u8,
259            ((words[0] & (mask_8bit << 32)) >> 32) as u8,
260            ((words[0] & (mask_8bit << 40)) >> 40) as u8,
261            ((words[0] & (mask_8bit << 48)) >> 48) as u8,
262            ((words[0] & (mask_8bit << 56)) >> 56) as u8,
263            (words[1] & mask_8bit) as u8,
264            ((words[1] & (mask_8bit << 8)) >> 8) as u8,
265            ((words[1] & (mask_8bit << 16)) >> 16) as u8,
266            ((words[1] & (mask_8bit << 24)) >> 24) as u8,
267            ((words[1] & (mask_8bit << 32)) >> 32) as u8,
268            ((words[1] & (mask_8bit << 40)) >> 40) as u8,
269            ((words[1] & (mask_8bit << 48)) >> 48) as u8,
270            ((words[1] & (mask_8bit << 56)) >> 56) as u8,
271            (words[2] & mask_8bit) as u8,
272            ((words[2] & (mask_8bit << 8)) >> 8) as u8,
273            ((words[2] & (mask_8bit << 16)) >> 16) as u8,
274            ((words[2] & (mask_8bit << 24)) >> 24) as u8,
275            ((words[2] & (mask_8bit << 32)) >> 32) as u8,
276            ((words[2] & (mask_8bit << 40)) >> 40) as u8,
277            ((words[2] & (mask_8bit << 48)) >> 48) as u8,
278            ((words[2] & (mask_8bit << 56)) >> 56) as u8,
279            (words[3] & mask_8bit) as u8,
280            ((words[3] & (mask_8bit << 8)) >> 8) as u8,
281            ((words[3] & (mask_8bit << 16)) >> 16) as u8,
282            ((words[3] & (mask_8bit << 24)) >> 24) as u8,
283            ((words[3] & (mask_8bit << 32)) >> 32) as u8,
284            ((words[3] & (mask_8bit << 40)) >> 40) as u8,
285            ((words[3] & (mask_8bit << 48)) >> 48) as u8,
286            ((words[3] & (mask_8bit << 56)) >> 56) as u8,
287        ])
288    }
289
290    fn vec_from_bytes_le(bytes: &[u8]) -> Vec<Self>
291    where
292        Self: Sized,
293    {
294        bytes
295            .chunks(Self::REPR_NUM_BYTES)
296            .map(Self::from_bytes_le)
297            .collect()
298    }
299}
300
301impl Zeroizable for Fq {
302    fn zeroize(&mut self) {
303        *self = Fq::ZERO;
304    }
305}
306
307impl Zeroizable for Bn256 {
308    fn zeroize(&mut self) {
309        *self = Bn256::zero();
310    }
311}
312
313impl PrimeOrderCurve for Bn256 {
314    type Scalar = <Bn256 as CurveExt>::ScalarExt;
315    type Base = <Bn256 as CurveExt>::Base;
316
317    const UNCOMPRESSED_CURVE_POINT_BYTEWIDTH: usize = 65;
318    const COMPRESSED_CURVE_POINT_BYTEWIDTH: usize = 34;
319    const SCALAR_ELEM_BYTEWIDTH: usize = 32;
320
321    fn zero() -> Self {
322        Bn256::identity()
323    }
324
325    fn generator() -> Self {
326        Bn256::generator()
327    }
328
329    fn random(rng: impl RngCore) -> Self {
330        <Bn256 as Group>::random(rng)
331    }
332
333    fn double(&self) -> Self {
334        Group::double(self)
335    }
336
337    fn projective_coordinates(&self) -> (Self::Base, Self::Base, Self::Base) {
338        if let Some((x, y)) = self.affine_coordinates() {
339            let z = Self::Base::one();
340            (x, y, z)
341        } else {
342            // it's the identity element
343            (Self::identity().x, Self::identity().y, Self::identity().z)
344        }
345    }
346
347    fn affine_coordinates(&self) -> Option<(Self::Base, Self::Base)> {
348        if self.z == Self::Base::zero() {
349            None
350        } else {
351            let z_inv = self.z.invert().unwrap();
352            Some((self.x * z_inv, self.y * z_inv))
353        }
354    }
355
356    /// The bytestring representation of the BN256 curve is a `[u8; 65]` with
357    /// the following semantic representation:
358    /// * The first `u8` byte represents whether the point is a point at
359    ///   infinity (in affine coordinates). 1 if it is at infinity, 0 otherwise.
360    /// * The next 32 `u8` bytes represent the x-coordinate of the point in
361    ///   little endian.
362    /// * The next 32 `u8` bytes represent the y-coordinate of the point in
363    ///   little endian.
364    fn to_bytes_uncompressed(&self) -> Vec<u8> {
365        // --- First get the affine coordinates. If `None`, we have a point at
366        // infinity. ---
367        let affine_coords = self.affine_coordinates();
368
369        if let Some((x, y)) = affine_coords {
370            let x_bytes = x.to_bytes();
371            let y_bytes = y.to_bytes();
372            std::iter::once(0_u8)
373                .chain(x_bytes)
374                .chain(y_bytes)
375                .collect_vec()
376        } else {
377            // Point at infinity
378            [1_u8; 65].to_vec()
379        }
380    }
381
382    /// The bytestring representation of the BN256 curve is a `[u8; 34]` with
383    /// the following semantic representation:
384    /// * The first `u8` byte represents whether the point is a point at
385    ///   infinity (in affine coordinates).
386    /// * The next 32 `u8` bytes represent the x-coordinate of the point in
387    ///   little endian.
388    /// * The final `u8` byte represents the sign of the y-coordinate of the
389    ///   point.
390    fn to_bytes_compressed(&self) -> Vec<u8> {
391        // --- First get the affine coordinates. If `None`, we have a point at
392        // infinity. ---
393        let affine_coords = self.affine_coordinates();
394
395        if let Some((x, y)) = affine_coords {
396            let x_bytes = x.to_bytes();
397            // 0 when y < q/2, 1 when y > q/2, where q is the prime order of
398            // base field this is because when y < q/2 and y is a square root,
399            // this means y is even. and if y > q/2 and y is a square root, this
400            // means y is odd. this is exactly what we are computing using & 1.
401            let y_sign = y.to_bytes()[0] & 1;
402            std::iter::once(0_u8)
403                .chain(x_bytes)
404                .chain(std::iter::once(y_sign))
405                .collect_vec()
406        } else {
407            // Point at infinity
408            [1_u8; 34].to_vec()
409        }
410    }
411
412    /// will return the elliptic curve point corresponding to an array of bytes
413    /// that represent an uncompressed point. we represent it as a a normalized
414    /// projective curve point (ie, the x and y coordinates are directly the
415    /// affine coordinates) so the z coordinate is always 1.
416    fn from_bytes_uncompressed(bytes: &[u8]) -> Self {
417        // assert that this is a 65 byte representation since it's uncompressed
418        assert_eq!(bytes.len(), Self::UNCOMPRESSED_CURVE_POINT_BYTEWIDTH);
419        // first check if it is a point at infinity
420        if bytes[0] == 1_u8 {
421            Self::identity()
422        } else {
423            let mut x_bytes_alloc = [0_u8; 32];
424            let x_bytes = &bytes[1..33];
425            x_bytes_alloc.copy_from_slice(x_bytes);
426
427            let mut y_bytes_alloc = [0_u8; 32];
428            let y_bytes = &bytes[33..];
429            y_bytes_alloc.copy_from_slice(y_bytes);
430
431            let x_coord = Self::Base::from_bytes(&x_bytes_alloc).unwrap();
432            let y_coord = Self::Base::from_bytes(&y_bytes_alloc).unwrap();
433            let point = Self {
434                x: x_coord,
435                y: y_coord,
436                z: Self::Base::one(),
437            };
438
439            assert_eq!(point.is_on_curve().unwrap_u8(), 1_u8);
440
441            point
442        }
443    }
444
445    /// will return the elliptic curve point corresponding to an array of bytes
446    /// that represent a compressed point. we represent it as a a normalized
447    /// projective curve point (ie, the x and y coordinates are directly the
448    /// affine coordinates) so the z coordinate is always 1.
449    fn from_bytes_compressed(bytes: &[u8]) -> Self {
450        // first check if it is a point at infinity
451        if bytes[0] == 1_u8 {
452            Self::identity()
453        } else {
454            let mut x_alloc_bytes = [0_u8; 32];
455            x_alloc_bytes.copy_from_slice(&bytes[1..33]);
456            let y_sign_byte: u8 = bytes[33];
457            let x_coord = Self::Base::from_bytes_le(&x_alloc_bytes);
458
459            Self::from_x_and_sign_y(x_coord, y_sign_byte)
460        }
461    }
462
463    /// Returns an elliptic curve point from the x and y coordinates.
464    fn from_xy(x: Self::Base, y: Self::Base) -> Self {
465        let point = Self {
466            x,
467            y,
468            z: Self::Base::ONE,
469        };
470        assert_eq!(point.is_on_curve().unwrap_u8(), 1_u8);
471        point
472    }
473
474    fn from_x_and_sign_y(x: Self::Base, y_sign: u8) -> Self {
475        // Ensure that `y_sign` is either 0 or 1
476        assert!(y_sign == 0 || y_sign == 1);
477
478        // y^2 = x^3 + ax + b
479        let y_square = (x.square() + Self::a()) * x + Self::b();
480        let one_y_sqrt = y_square.sqrt().unwrap();
481
482        // Flip y-sign if needed
483        let y_coord = if (one_y_sqrt.to_bytes()[0] % 2) ^ y_sign == 0 {
484            one_y_sqrt
485        } else {
486            one_y_sqrt.neg()
487        };
488
489        Self {
490            x,
491            y: y_coord,
492            z: Self::Base::one(),
493        }
494    }
495
496    /// Simple double-and-add method for scalar multiplication, optimized for
497    /// when the scalar is representable as an unsigned integer (u8, u16, u32,
498    /// u64, or u128).
499    fn scalar_mult_unsigned_integer<T: Unsigned + Zero + ToBytes>(&self, scalar: &T) -> Self {
500        let scalar_repr = scalar.to_be_bytes();
501        let bits = scalar_repr
502            .as_ref()
503            .iter()
504            .flat_map(|byte| (0..8).rev().map(move |i| ((byte >> i) & 1u8) != 0))
505            .collect_vec();
506
507        let mut acc = Self::identity();
508        for bit in bits {
509            acc = PrimeOrderCurve::double(&acc);
510            // The conditional_select is a constant-time selector using the XOR
511            // method.
512            acc = Self::conditional_select(&acc, &(acc + self), Choice::from(bit as u8));
513        }
514        acc
515    }
516}
517
518/// wrapper needed in order to implement RngCore for the Sha256 digest. we need
519/// this to implement RngCore so that it can be used to generate random group
520/// elements by calling `PrimeOrderCurve::random`
521pub struct Sha3XofReaderWrapper {
522    item: XofReaderCoreWrapper<Shake256ReaderCore>,
523}
524
525impl Sha3XofReaderWrapper {
526    pub fn new(item: XofReaderCoreWrapper<Shake256ReaderCore>) -> Self {
527        Self { item }
528    }
529}
530
531impl RngCore for Sha3XofReaderWrapper {
532    fn next_u32(&mut self) -> u32 {
533        let mut buffer: [u8; 4] = [0; 4];
534        self.item.read(&mut buffer);
535        u32::from_le_bytes(buffer)
536    }
537
538    fn next_u64(&mut self) -> u64 {
539        let mut buffer: [u8; 8] = [0; 8];
540        self.item.read(&mut buffer);
541        u64::from_le_bytes(buffer)
542    }
543
544    fn fill_bytes(&mut self, dest: &mut [u8]) {
545        self.item.read(dest);
546    }
547
548    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
549        self.item.read(dest);
550        Ok(())
551    }
552}
553
/// A deterministic, *non-random* RNG intended only for tests.
///
/// Despite the name, the output is not constant. `next_u32` and `next_u64`
/// increment an internal `u8` counter and return its new value. `fill_bytes`
/// writes the current counter value into the first byte of the buffer, zeroes
/// the rest, and does not advance the counter.
///
/// **Not cryptographically secure**, even though it implements `CryptoRng`.
/// Never use it outside of tests.
pub struct ConstantRng {
    value: u8,
}
558
559impl ConstantRng {
560    pub fn new(value: u8) -> Self {
561        Self { value }
562    }
563}
564
565impl RngCore for ConstantRng {
566    fn next_u32(&mut self) -> u32 {
567        self.value += 1;
568        self.value as u32
569    }
570
571    fn next_u64(&mut self) -> u64 {
572        self.value += 1;
573        self.value as u64
574    }
575
576    fn fill_bytes(&mut self, dest: &mut [u8]) {
577        if !dest.is_empty() {
578            dest[0] = self.value;
579            for byte in dest.iter_mut().skip(1) {
580                *byte = 0;
581            }
582        }
583    }
584
585    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
586        self.fill_bytes(dest);
587        Ok(())
588    }
589}
590
// NOTE(review): `ConstantRng` is fully predictable, so this `CryptoRng` marker
// is a deliberate lie. Presumably it exists so the test RNG can be passed to
// APIs bounded by `RngCore + CryptoRng`. It must never be used to produce real
// secrets or keys.
impl CryptoRng for ConstantRng {}