frontend/components/binary_operations/
logical_shift.rs1use std::cmp::{max, min};
4
5use itertools::Itertools;
6use shared_types::Field;
7
8use crate::layouter::builder::{CircuitBuilder, NodeRef};
9
10#[derive(Clone, Debug)]
20pub struct ShiftNode<F: Field> {
21 output: NodeRef<F>,
22}
23
24impl<F: Field> ShiftNode<F> {
25 pub fn new(
33 builder_ref: &mut CircuitBuilder<F>,
34 num_vars: usize,
35 shift_amount: i32,
36 input: &NodeRef<F>,
37 ) -> Self {
38 let shift_wirings = generate_shift_wirings(num_vars, shift_amount);
41 let output = builder_ref.add_identity_gate_node(input, shift_wirings, num_vars, None);
42
43 Self { output }
44 }
45
46 pub fn get_output(&self) -> NodeRef<F> {
48 self.output.clone()
49 }
50}
51
52fn generate_shift_wirings(num_vars: usize, shift_amount: i32) -> Vec<(u32, u32)> {
53 assert!(num_vars <= 30);
61
62 let shift_amount = max(min(shift_amount, 1 << num_vars), -(1 << num_vars));
64
65 if shift_amount >= 0 {
66 let shift_amount = shift_amount as u32;
67 (0..(1 << num_vars) - shift_amount)
68 .map(|i| (i + shift_amount, i))
69 .collect_vec()
70 } else {
71 let shift_amount: u32 = shift_amount.unsigned_abs();
72 (0..(1 << num_vars) - shift_amount)
73 .map(|i| (i, i + shift_amount))
74 .collect_vec()
75 }
76}
77
#[cfg(test)]
mod test {
    use super::*;

    /// A shift of +1 over 8 positions pairs every index `i + 1` with `i`.
    #[test]
    fn test_8bit_right_shift_by_1() {
        let shift_wirings = generate_shift_wirings(3, 1);

        assert_eq!(
            shift_wirings,
            vec![(1, 0), (2, 1), (3, 2), (4, 3), (5, 4), (6, 5), (7, 6)]
        );
    }

    /// Same as the 8-position case, over 256 positions.
    #[test]
    fn test_256bit_right_shift_by_1() {
        let shift_wirings = generate_shift_wirings(8, 1);

        assert_eq!(
            shift_wirings,
            (0..255).map(|i| (i + 1, i)).collect_vec()
        );
    }

    /// A shift of -1 over 8 positions pairs every index `i` with `i + 1`.
    #[test]
    fn test_8bit_left_shift_by_1() {
        let shift_wirings = generate_shift_wirings(3, -1);

        assert_eq!(
            shift_wirings,
            vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]
        );
    }

    /// Same as the 8-position case, over 256 positions.
    #[test]
    fn test_256bit_left_shift_by_1() {
        let shift_wirings = generate_shift_wirings(8, -1);

        assert_eq!(
            shift_wirings,
            (0..255).map(|i| (i, i + 1)).collect_vec()
        );
    }

    /// A zero shift pairs every index with itself.
    #[test]
    fn test_zero_shift() {
        let shift_wirings = generate_shift_wirings(3, 0);
        assert_eq!(
            shift_wirings,
            vec![
                (0, 0),
                (1, 1),
                (2, 2),
                (3, 3),
                (4, 4),
                (5, 5),
                (6, 6),
                (7, 7)
            ]
        );
    }

    /// Shifting by exactly the number of positions leaves no pairs.
    #[test]
    fn test_8bit_right_shift_by_bit_length() {
        let shift_wirings = generate_shift_wirings(3, 1 << 3);

        assert_eq!(shift_wirings, vec![]);
    }

    /// Positive shifts beyond the number of positions are clamped.
    #[test]
    fn test_8bit_right_shift_by_more_than_bit_length() {
        let shift_wirings = generate_shift_wirings(3, 100);

        assert_eq!(shift_wirings, vec![]);
    }

    /// Negative counterpart of the full-length shift: no pairs remain.
    #[test]
    fn test_8bit_left_shift_by_bit_length() {
        let shift_wirings = generate_shift_wirings(3, -(1 << 3));

        assert_eq!(shift_wirings, vec![]);
    }

    /// Negative shifts beyond the number of positions are clamped too.
    #[test]
    fn test_8bit_left_shift_by_more_than_bit_length() {
        let shift_wirings = generate_shift_wirings(3, -100);

        assert_eq!(shift_wirings, vec![]);
    }

    /// `i32::MIN` must be clamped before its magnitude is taken.
    #[test]
    fn test_8bit_left_shift_by_i32_min() {
        let shift_wirings = generate_shift_wirings(3, i32::MIN);

        assert_eq!(shift_wirings, vec![]);
    }
}