frontend/components/binary_operations/rotate_bits.rs
//! Implements bit rotation gates.
2
3use shared_types::Field;
4
5use crate::layouter::builder::{CircuitBuilder, NodeRef};
6/// A component that performs bit (wire rotation) using
7/// [crate::layouter::nodes::identity_gate::IdentityGateNode] (i.e., rewire).
8///
9/// TODO: Generalize to a wire shuffle (permute) gate.
10///
11/// Both left and right rotations are supported by using
12/// negative/positive values in the `rotate_amount` parameter. +ve value
13/// means rotate left
14///
15/// This is a rotate instruction, meaning that bits which are shifted
16/// beyond num_vars of the word, are appended to the other side in the
17/// same order.
18///
19/// Requires that the input node has already been verified to contain
20/// binary digits.
21#[derive(Clone, Debug)]
22pub struct RotateNode<F: Field> {
23 output: NodeRef<F>,
24}
25
26impl<F: Field> RotateNode<F> {
27 /// Create a new [crate::components::binary_operations::logical_shift::ShiftNode] that performs a rotation by `rotate_amount` (to the right if
28 /// `rotate_amount > 0` or to the left if `rotate_amount < 0`) on `input` node which contains
29 /// `2^num_vars` binary digits.
30 ///
31 /// # Requires
32 /// `input` is assumed to only contain binary digits (i.e. only values from the set
33 /// `{F::ZERO, F::ONE}` for a field `F`).
34 pub fn new(
35 builder_ref: &mut CircuitBuilder<F>,
36 num_vars: usize,
37 rotate_amount: i32,
38 input: &NodeRef<F>,
39 ) -> Self {
40 // Compute the bit reroutings that effectively shift the
41 // input MLE by the appropriate amount.
42 let rot_wirings = generate_rot_wirings(num_vars, rotate_amount);
43 let output = builder_ref.add_identity_gate_node(input, rot_wirings, num_vars, None);
44
45 Self { output }
46 }
47
48 /// Returns a reference to the node containing the shifted value.
49 pub fn get_output(&self) -> NodeRef<F> {
50 self.output.clone()
51 }
52}
53
/// Builds the wire pairs that realise a rotation by `rot_amount` over
/// `2^arity` wires.
///
/// For every index `i` in `0..2^arity`, in increasing order, the result
/// contains the pair `((i + rot_amount) mod 2^arity, i)`, where `mod` is the
/// non-negative (Euclidean) remainder. A positive `rot_amount` denotes a right
/// rotation and a negative one a left rotation; amounts whose magnitude
/// exceeds `2^arity` wrap around.
///
/// # Panics
/// Panics if `arity > 30`. The bound ensures that the signed 32-bit
/// `rot_amount` can represent every possible rotation of `2^arity` wires.
fn generate_rot_wirings(arity: usize, rot_amount: i32) -> Vec<(u32, u32)> {
    assert!(arity <= 30);

    // Work in `i64`: adding an index (< 2^30) to an arbitrary `i32` amount can
    // overflow `i32`, e.g. for `rot_amount == i32::MAX`.
    let wire_count = 1_i64 << arity;

    // Reduce the amount once, so the per-wire arithmetic only ever sees a
    // value in `0..wire_count`.
    let offset = i64::from(rot_amount).rem_euclid(wire_count);

    (0..wire_count)
        .map(|i| {
            let rotated = (i + offset) % wire_count;
            // Both values lie in `0..2^30`, so the casts to `u32` are lossless.
            (rotated as u32, i as u32)
        })
        .collect()
}
77
#[cfg(test)]
mod test {
    use super::*;

    /// Reference wirings for `size` wires rotated by `step` positions, where
    /// `step` is already reduced to `0..size`: the pair `((i + step) % size, i)`
    /// for each `i` in increasing order.
    fn reference_wirings(size: u32, step: u32) -> Vec<(u32, u32)> {
        (0..size).map(|i| ((i + step) % size, i)).collect()
    }

    /// A rotation amount of `+1` over 8 wires (arity 3), spelled out pair by pair.
    #[test]
    fn test_8bit_right_rotate_by_one() {
        let expected: Vec<(u32, u32)> = vec![
            (1, 0),
            (2, 1),
            (3, 2),
            (4, 3),
            (5, 4),
            (6, 5),
            (7, 6),
            (0, 7),
        ];

        assert_eq!(generate_rot_wirings(3, 1), expected);
    }

    /// A rotation amount of `-1` over 8 wires (arity 3), spelled out pair by pair.
    #[test]
    fn test_8bit_left_rotate_by_one() {
        let expected: Vec<(u32, u32)> = vec![
            (7, 0),
            (0, 1),
            (1, 2),
            (2, 3),
            (3, 4),
            (4, 5),
            (5, 6),
            (6, 7),
        ];

        assert_eq!(generate_rot_wirings(3, -1), expected);
    }

    /// A rotation amount of `+1` over 256 wires (arity 8). Despite the
    /// `_by_rand` suffix, the amount is fixed at one.
    #[test]
    fn test_256bit_right_rotate_by_rand() {
        // (1, 0), (2, 1), ..., (255, 254), (0, 255)
        let expected = reference_wirings(256, 1);
        assert_eq!(generate_rot_wirings(8, 1), expected);
    }

    /// A rotation amount of `-1` over 256 wires (arity 8). Despite the
    /// `_by_rand` suffix, the amount is fixed at one.
    #[test]
    fn test_256bit_left_rotate_by_rand() {
        // (255, 0), (0, 1), ..., (254, 255); -1 is congruent to 255 modulo 256.
        let expected = reference_wirings(256, 255);
        assert_eq!(generate_rot_wirings(8, -1), expected);
    }
}