1use crate::{ff_field, Zeroizable};
2use crate::{
3 halo2curves::{bn256::G1 as Bn256, CurveExt},
4 Field, HasByteRepresentation,
5};
6use ark_std::rand::{self, RngCore};
7use halo2curves::bn256::{Fq, Fr};
8use itertools::Itertools;
9use num::traits::ToBytes;
10use num::{Unsigned, Zero};
11use rand::CryptoRng;
12use serde::{Deserialize, Serialize};
13use sha3::{
14 digest::{core_api::XofReaderCoreWrapper, XofReader},
15 Shake256ReaderCore,
16};
17use std::{
30 fmt,
31 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
32};
33use subtle::{Choice, ConditionallySelectable};
34use crate::halo2curves::group::Group;
37
38#[cfg(test)]
39mod tests;
41
/// A prime-order elliptic-curve group abstracted over its scalar and base fields.
///
/// Group arithmetic is supplied through the operator supertraits (`Add`, `Sub`,
/// `Neg`, and `Mul`/`MulAssign` by `Self::Scalar`). The trait methods add
/// coordinate access and byte (de)serialization. The concrete `Bn256` (halo2curves
/// BN254 G1) implementation in this file defines the byte formats actually used.
pub trait PrimeOrderCurve:
    Copy
    + Clone
    + Sized
    + Send
    + Sync
    + fmt::Debug
    + Eq
    + 'static
    + Neg<Output = Self>
    + Mul<Self::Scalar, Output = Self>
    + Add<Self, Output = Self>
    + Sub<Self, Output = Self>
    + AddAssign<Self>
    + SubAssign<Self>
    + MulAssign<Self::Scalar>
    + Serialize
    + for<'de> Deserialize<'de>
    + Zeroizable
{
    /// Field of scalars that points are multiplied by (see the `Mul` supertrait).
    type Scalar: Field;

    /// Field over which point coordinates are expressed.
    type Base: Field;

    /// Expected length in bytes of `to_bytes_uncompressed` output
    /// (65 for `Bn256`: tag byte + 32-byte x + 32-byte y).
    const UNCOMPRESSED_CURVE_POINT_BYTEWIDTH: usize;
    /// Expected length in bytes of `to_bytes_compressed` output
    /// (34 for `Bn256`: tag byte + 32-byte x + 1 y-parity byte).
    const COMPRESSED_CURVE_POINT_BYTEWIDTH: usize;
    /// Byte width of a serialized scalar (32 for `Bn256`).
    const SCALAR_ELEM_BYTEWIDTH: usize;

    /// Additive identity of the group (the `Bn256` impl returns `identity()`).
    fn zero() -> Self;

    /// A fixed generator of the group.
    fn generator() -> Self;

    /// Samples a group element using randomness drawn from `rng`.
    fn random(rng: impl RngCore) -> Self;

    /// Returns `self + self`.
    fn double(&self) -> Self;

    /// Returns `(x, y, z)` coordinates. The `Bn256` impl normalizes finite points
    /// to `z = 1` and returns the identity's raw coordinates otherwise.
    fn projective_coordinates(&self) -> (Self::Base, Self::Base, Self::Base);

    /// Returns affine `(x, y)`, or `None` when the point has no affine form
    /// (the identity; `z == 0` in the `Bn256` impl).
    fn affine_coordinates(&self) -> Option<(Self::Base, Self::Base)>;

    /// Serializes the point with both coordinates. The format is
    /// implementation-defined; `Bn256` writes tag `0` + x + y, or all `1` bytes
    /// for the identity.
    fn to_bytes_uncompressed(&self) -> Vec<u8>;

    /// Serializes the point as x plus the parity of y. The format is
    /// implementation-defined; `Bn256` writes tag `0` + x + parity byte, or all
    /// `1` bytes for the identity.
    fn to_bytes_compressed(&self) -> Vec<u8>;

    /// Inverse of `to_bytes_uncompressed`. The `Bn256` impl panics on a wrong
    /// length, non-decodable coordinates, or a point not on the curve.
    fn from_bytes_uncompressed(bytes: &[u8]) -> Self;

    /// Inverse of `to_bytes_compressed`; recovers y from x and the parity byte.
    fn from_bytes_compressed(bytes: &[u8]) -> Self;

    /// Builds a point from affine coordinates (`Bn256` asserts it is on the curve).
    fn from_xy(x: Self::Base, y: Self::Base) -> Self;

    /// Builds a point from x and the low bit of y's byte encoding; `y_sign` must be
    /// `0` or `1` in the `Bn256` impl.
    fn from_x_and_sign_y(x: Self::Base, y_sign: u8) -> Self;

    /// Multiplies the point by an unsigned machine integer rather than a
    /// `Self::Scalar`; `Bn256` consumes the integer's big-endian bytes.
    fn scalar_mult_unsigned_integer<T: Unsigned + Zero + ToBytes>(&self, scalar: &T) -> Self;
}
119
120impl HasByteRepresentation for Fr {
122 const REPR_NUM_BYTES: usize = 32;
123
124 fn from_bytes_le(bytes: &[u8]) -> Self {
125 if bytes.len() > Self::REPR_NUM_BYTES {
126 panic!("Error: Attempted to convert from greater than 32-length byte vector into Fr")
127 }
128 let bytes_len_32_slice: [u8; 32] = if bytes.len() < Self::REPR_NUM_BYTES {
130 let padding = vec![0_u8; Self::REPR_NUM_BYTES - bytes.len()];
131 let bytes_owned = bytes.to_owned();
132 bytes_owned
133 .into_iter()
134 .chain(padding)
135 .collect_vec()
136 .try_into()
137 .unwrap()
138 } else {
139 bytes.try_into().unwrap()
140 };
141 Fr::from_bytes(&bytes_len_32_slice).unwrap()
142 }
143
144 fn to_bytes_le(&self) -> Vec<u8> {
145 Fr::to_bytes(self).to_vec()
146 }
147
148 fn to_u64s_le(&self) -> Vec<u64> {
149 let bytes = self.to_bytes_le();
150
151 let fold_bytes = |acc, x: &u8| (acc << 8) + (*x as u64);
152
153 vec![
154 bytes[0..8].iter().rev().fold(0, fold_bytes),
155 bytes[8..16].iter().rev().fold(0, fold_bytes),
156 bytes[16..24].iter().rev().fold(0, fold_bytes),
157 bytes[24..32].iter().rev().fold(0, fold_bytes),
158 ]
159 }
160
161 fn from_u64s_le(words: Vec<u64>) -> Self
162 where
163 Self: Sized,
164 {
165 let mask_8bit = (1_u64 << 8) - 1;
166
167 Self::from_bytes_le(&[
168 (words[0] & mask_8bit) as u8,
169 ((words[0] & (mask_8bit << 8)) >> 8) as u8,
170 ((words[0] & (mask_8bit << 16)) >> 16) as u8,
171 ((words[0] & (mask_8bit << 24)) >> 24) as u8,
172 ((words[0] & (mask_8bit << 32)) >> 32) as u8,
173 ((words[0] & (mask_8bit << 40)) >> 40) as u8,
174 ((words[0] & (mask_8bit << 48)) >> 48) as u8,
175 ((words[0] & (mask_8bit << 56)) >> 56) as u8,
176 (words[1] & mask_8bit) as u8,
177 ((words[1] & (mask_8bit << 8)) >> 8) as u8,
178 ((words[1] & (mask_8bit << 16)) >> 16) as u8,
179 ((words[1] & (mask_8bit << 24)) >> 24) as u8,
180 ((words[1] & (mask_8bit << 32)) >> 32) as u8,
181 ((words[1] & (mask_8bit << 40)) >> 40) as u8,
182 ((words[1] & (mask_8bit << 48)) >> 48) as u8,
183 ((words[1] & (mask_8bit << 56)) >> 56) as u8,
184 (words[2] & mask_8bit) as u8,
185 ((words[2] & (mask_8bit << 8)) >> 8) as u8,
186 ((words[2] & (mask_8bit << 16)) >> 16) as u8,
187 ((words[2] & (mask_8bit << 24)) >> 24) as u8,
188 ((words[2] & (mask_8bit << 32)) >> 32) as u8,
189 ((words[2] & (mask_8bit << 40)) >> 40) as u8,
190 ((words[2] & (mask_8bit << 48)) >> 48) as u8,
191 ((words[2] & (mask_8bit << 56)) >> 56) as u8,
192 (words[3] & mask_8bit) as u8,
193 ((words[3] & (mask_8bit << 8)) >> 8) as u8,
194 ((words[3] & (mask_8bit << 16)) >> 16) as u8,
195 ((words[3] & (mask_8bit << 24)) >> 24) as u8,
196 ((words[3] & (mask_8bit << 32)) >> 32) as u8,
197 ((words[3] & (mask_8bit << 40)) >> 40) as u8,
198 ((words[3] & (mask_8bit << 48)) >> 48) as u8,
199 ((words[3] & (mask_8bit << 56)) >> 56) as u8,
200 ])
201 }
202
203 fn vec_from_bytes_le(bytes: &[u8]) -> Vec<Self>
204 where
205 Self: Sized,
206 {
207 bytes
208 .chunks(Self::REPR_NUM_BYTES)
209 .map(Self::from_bytes_le)
210 .collect()
211 }
212}
213
214impl Zeroizable for Fr {
215 fn zeroize(&mut self) {
216 *self = Fr::ZERO;
217 }
218}
219
220impl HasByteRepresentation for Fq {
221 const REPR_NUM_BYTES: usize = 32;
222
223 fn from_bytes_le(bytes: &[u8]) -> Self {
224 if bytes.len() != Self::REPR_NUM_BYTES {
225 panic!("Error: Attempted to convert from non-32-length byte vector into Fr")
226 }
227 let bytes_len_32_slice: [u8; 32] = bytes.try_into().unwrap();
228 Fq::from_bytes(&bytes_len_32_slice).unwrap()
229 }
230
231 fn to_bytes_le(&self) -> Vec<u8> {
232 Fq::to_bytes(self).to_vec()
233 }
234
235 fn to_u64s_le(&self) -> Vec<u64> {
236 let bytes = self.to_bytes_le();
237
238 let fold_bytes = |acc, x: &u8| (acc << 8) + (*x as u64);
239
240 vec![
241 bytes[0..8].iter().rev().fold(0, fold_bytes),
242 bytes[8..16].iter().rev().fold(0, fold_bytes),
243 bytes[16..24].iter().rev().fold(0, fold_bytes),
244 bytes[24..32].iter().rev().fold(0, fold_bytes),
245 ]
246 }
247
248 fn from_u64s_le(words: Vec<u64>) -> Self
249 where
250 Self: Sized,
251 {
252 let mask_8bit = (1_u64 << 8) - 1;
253
254 Self::from_bytes_le(&[
255 (words[0] & mask_8bit) as u8,
256 ((words[0] & (mask_8bit << 8)) >> 8) as u8,
257 ((words[0] & (mask_8bit << 16)) >> 16) as u8,
258 ((words[0] & (mask_8bit << 24)) >> 24) as u8,
259 ((words[0] & (mask_8bit << 32)) >> 32) as u8,
260 ((words[0] & (mask_8bit << 40)) >> 40) as u8,
261 ((words[0] & (mask_8bit << 48)) >> 48) as u8,
262 ((words[0] & (mask_8bit << 56)) >> 56) as u8,
263 (words[1] & mask_8bit) as u8,
264 ((words[1] & (mask_8bit << 8)) >> 8) as u8,
265 ((words[1] & (mask_8bit << 16)) >> 16) as u8,
266 ((words[1] & (mask_8bit << 24)) >> 24) as u8,
267 ((words[1] & (mask_8bit << 32)) >> 32) as u8,
268 ((words[1] & (mask_8bit << 40)) >> 40) as u8,
269 ((words[1] & (mask_8bit << 48)) >> 48) as u8,
270 ((words[1] & (mask_8bit << 56)) >> 56) as u8,
271 (words[2] & mask_8bit) as u8,
272 ((words[2] & (mask_8bit << 8)) >> 8) as u8,
273 ((words[2] & (mask_8bit << 16)) >> 16) as u8,
274 ((words[2] & (mask_8bit << 24)) >> 24) as u8,
275 ((words[2] & (mask_8bit << 32)) >> 32) as u8,
276 ((words[2] & (mask_8bit << 40)) >> 40) as u8,
277 ((words[2] & (mask_8bit << 48)) >> 48) as u8,
278 ((words[2] & (mask_8bit << 56)) >> 56) as u8,
279 (words[3] & mask_8bit) as u8,
280 ((words[3] & (mask_8bit << 8)) >> 8) as u8,
281 ((words[3] & (mask_8bit << 16)) >> 16) as u8,
282 ((words[3] & (mask_8bit << 24)) >> 24) as u8,
283 ((words[3] & (mask_8bit << 32)) >> 32) as u8,
284 ((words[3] & (mask_8bit << 40)) >> 40) as u8,
285 ((words[3] & (mask_8bit << 48)) >> 48) as u8,
286 ((words[3] & (mask_8bit << 56)) >> 56) as u8,
287 ])
288 }
289
290 fn vec_from_bytes_le(bytes: &[u8]) -> Vec<Self>
291 where
292 Self: Sized,
293 {
294 bytes
295 .chunks(Self::REPR_NUM_BYTES)
296 .map(Self::from_bytes_le)
297 .collect()
298 }
299}
300
301impl Zeroizable for Fq {
302 fn zeroize(&mut self) {
303 *self = Fq::ZERO;
304 }
305}
306
307impl Zeroizable for Bn256 {
308 fn zeroize(&mut self) {
309 *self = Bn256::zero();
310 }
311}
312
313impl PrimeOrderCurve for Bn256 {
314 type Scalar = <Bn256 as CurveExt>::ScalarExt;
315 type Base = <Bn256 as CurveExt>::Base;
316
317 const UNCOMPRESSED_CURVE_POINT_BYTEWIDTH: usize = 65;
318 const COMPRESSED_CURVE_POINT_BYTEWIDTH: usize = 34;
319 const SCALAR_ELEM_BYTEWIDTH: usize = 32;
320
321 fn zero() -> Self {
322 Bn256::identity()
323 }
324
325 fn generator() -> Self {
326 Bn256::generator()
327 }
328
329 fn random(rng: impl RngCore) -> Self {
330 <Bn256 as Group>::random(rng)
331 }
332
333 fn double(&self) -> Self {
334 Group::double(self)
335 }
336
337 fn projective_coordinates(&self) -> (Self::Base, Self::Base, Self::Base) {
338 if let Some((x, y)) = self.affine_coordinates() {
339 let z = Self::Base::one();
340 (x, y, z)
341 } else {
342 (Self::identity().x, Self::identity().y, Self::identity().z)
344 }
345 }
346
347 fn affine_coordinates(&self) -> Option<(Self::Base, Self::Base)> {
348 if self.z == Self::Base::zero() {
349 None
350 } else {
351 let z_inv = self.z.invert().unwrap();
352 Some((self.x * z_inv, self.y * z_inv))
353 }
354 }
355
356 fn to_bytes_uncompressed(&self) -> Vec<u8> {
365 let affine_coords = self.affine_coordinates();
368
369 if let Some((x, y)) = affine_coords {
370 let x_bytes = x.to_bytes();
371 let y_bytes = y.to_bytes();
372 std::iter::once(0_u8)
373 .chain(x_bytes)
374 .chain(y_bytes)
375 .collect_vec()
376 } else {
377 [1_u8; 65].to_vec()
379 }
380 }
381
382 fn to_bytes_compressed(&self) -> Vec<u8> {
391 let affine_coords = self.affine_coordinates();
394
395 if let Some((x, y)) = affine_coords {
396 let x_bytes = x.to_bytes();
397 let y_sign = y.to_bytes()[0] & 1;
402 std::iter::once(0_u8)
403 .chain(x_bytes)
404 .chain(std::iter::once(y_sign))
405 .collect_vec()
406 } else {
407 [1_u8; 34].to_vec()
409 }
410 }
411
412 fn from_bytes_uncompressed(bytes: &[u8]) -> Self {
417 assert_eq!(bytes.len(), Self::UNCOMPRESSED_CURVE_POINT_BYTEWIDTH);
419 if bytes[0] == 1_u8 {
421 Self::identity()
422 } else {
423 let mut x_bytes_alloc = [0_u8; 32];
424 let x_bytes = &bytes[1..33];
425 x_bytes_alloc.copy_from_slice(x_bytes);
426
427 let mut y_bytes_alloc = [0_u8; 32];
428 let y_bytes = &bytes[33..];
429 y_bytes_alloc.copy_from_slice(y_bytes);
430
431 let x_coord = Self::Base::from_bytes(&x_bytes_alloc).unwrap();
432 let y_coord = Self::Base::from_bytes(&y_bytes_alloc).unwrap();
433 let point = Self {
434 x: x_coord,
435 y: y_coord,
436 z: Self::Base::one(),
437 };
438
439 assert_eq!(point.is_on_curve().unwrap_u8(), 1_u8);
440
441 point
442 }
443 }
444
445 fn from_bytes_compressed(bytes: &[u8]) -> Self {
450 if bytes[0] == 1_u8 {
452 Self::identity()
453 } else {
454 let mut x_alloc_bytes = [0_u8; 32];
455 x_alloc_bytes.copy_from_slice(&bytes[1..33]);
456 let y_sign_byte: u8 = bytes[33];
457 let x_coord = Self::Base::from_bytes_le(&x_alloc_bytes);
458
459 Self::from_x_and_sign_y(x_coord, y_sign_byte)
460 }
461 }
462
463 fn from_xy(x: Self::Base, y: Self::Base) -> Self {
465 let point = Self {
466 x,
467 y,
468 z: Self::Base::ONE,
469 };
470 assert_eq!(point.is_on_curve().unwrap_u8(), 1_u8);
471 point
472 }
473
474 fn from_x_and_sign_y(x: Self::Base, y_sign: u8) -> Self {
475 assert!(y_sign == 0 || y_sign == 1);
477
478 let y_square = (x.square() + Self::a()) * x + Self::b();
480 let one_y_sqrt = y_square.sqrt().unwrap();
481
482 let y_coord = if (one_y_sqrt.to_bytes()[0] % 2) ^ y_sign == 0 {
484 one_y_sqrt
485 } else {
486 one_y_sqrt.neg()
487 };
488
489 Self {
490 x,
491 y: y_coord,
492 z: Self::Base::one(),
493 }
494 }
495
496 fn scalar_mult_unsigned_integer<T: Unsigned + Zero + ToBytes>(&self, scalar: &T) -> Self {
500 let scalar_repr = scalar.to_be_bytes();
501 let bits = scalar_repr
502 .as_ref()
503 .iter()
504 .flat_map(|byte| (0..8).rev().map(move |i| ((byte >> i) & 1u8) != 0))
505 .collect_vec();
506
507 let mut acc = Self::identity();
508 for bit in bits {
509 acc = PrimeOrderCurve::double(&acc);
510 acc = Self::conditional_select(&acc, &(acc + self), Choice::from(bit as u8));
513 }
514 acc
515 }
516}
517
/// Adapter that lets a SHAKE256 extendable-output reader be used as a [`RngCore`].
///
/// All output bytes are read directly from the wrapped XOF stream, so the sequence
/// is fully determined by whatever was absorbed before the reader was created.
pub struct Sha3XofReaderWrapper {
    /// SHAKE256 XOF reader that every requested byte is drawn from.
    item: XofReaderCoreWrapper<Shake256ReaderCore>,
}
524
525impl Sha3XofReaderWrapper {
526 pub fn new(item: XofReaderCoreWrapper<Shake256ReaderCore>) -> Self {
527 Self { item }
528 }
529}
530
531impl RngCore for Sha3XofReaderWrapper {
532 fn next_u32(&mut self) -> u32 {
533 let mut buffer: [u8; 4] = [0; 4];
534 self.item.read(&mut buffer);
535 u32::from_le_bytes(buffer)
536 }
537
538 fn next_u64(&mut self) -> u64 {
539 let mut buffer: [u8; 8] = [0; 8];
540 self.item.read(&mut buffer);
541 u64::from_le_bytes(buffer)
542 }
543
544 fn fill_bytes(&mut self, dest: &mut [u8]) {
545 self.item.read(dest);
546 }
547
548 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
549 self.item.read(dest);
550 Ok(())
551 }
552}
553
/// Deterministic, non-random stand-in for an RNG.
///
/// `next_u32`/`next_u64` bump the stored byte by one and return it widened, while
/// `fill_bytes` writes the current byte followed by zeros without changing it.
pub struct ConstantRng {
    /// Byte emitted by the generator; advanced by `next_u32`/`next_u64`.
    value: u8,
}
558
559impl ConstantRng {
560 pub fn new(value: u8) -> Self {
561 Self { value }
562 }
563}
564
565impl RngCore for ConstantRng {
566 fn next_u32(&mut self) -> u32 {
567 self.value += 1;
568 self.value as u32
569 }
570
571 fn next_u64(&mut self) -> u64 {
572 self.value += 1;
573 self.value as u64
574 }
575
576 fn fill_bytes(&mut self, dest: &mut [u8]) {
577 if !dest.is_empty() {
578 dest[0] = self.value;
579 for byte in dest.iter_mut().skip(1) {
580 *byte = 0;
581 }
582 }
583 }
584
585 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
586 self.fill_bytes(dest);
587 Ok(())
588 }
589}
590
// NOTE(review): `ConstantRng` is fully deterministic and predictable. This marker
// impl only lets it satisfy `CryptoRng` bounds (presumably for tests). It must
// never be used where genuine cryptographic randomness is required.
impl CryptoRng for ConstantRng {}