Skip to content

Commit 96920d7

Browse files
committed
Safe modular arithmetic and number theoretic transform
1 parent a21599d commit 96920d7

File tree

5 files changed

+254
-77
lines changed

5 files changed

+254
-77
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
language: rust
22
rust:
3-
- 1.31.1 # Version currently supported by Codeforces
3+
# - 1.31.1 # Version currently supported by Codeforces
44
- stable
55
- beta
66
- nightly

README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@ Most competition programmers rely on C++ for its fast execution time. However, i
2121

2222
To my delight, I found that Rust eliminates entire classes of bugs, while reducing visual clutter to make the rest easier to spot. And, it's *fast*. There's a learning curve, to be sure. However, a proficient Rust programmer stands to gain a competitive advantage as well as a more pleasant experience!
2323

24+
For help in getting started, you may check out [some of my past submissions](https://codeforces.com/contest/1168/submission/54859899).
25+
2426
Other online judges that support Rust:
2527
- [Timus](http://acm.timus.ru/help.aspx?topic=rust)
2628
- [SPOJ](http://www.spoj.com/)
2729

2830
## Programming Language Advocacy
2931

30-
My other goal is to appeal to developers who feel limited by mainstream, arguably outdated, programming languages. Perhaps we have a better way.
32+
My other goal is to appeal to developers who feel limited by mainstream, arguably outdated, programming languages. Perhaps we have a better option.
3133

3234
Rather than try to persuade you with words, this repository aims to show by example. If you're new to Rust, see [Jim Blandy's *Why Rust?*](http://www.oreilly.com/programming/free/files/why-rust.pdf) for a brief introduction, or just [dive in!](https://doc.rust-lang.org/book/2018-edition)
3335

@@ -37,6 +39,8 @@ Rather than try to persuade you with words, this repository aims to show by exam
3739
- [Network flows](src/graph/flow.rs): Dinic's blocking flow, Hopcroft-Karp bipartite matching, min cost max flow
3840
- [Connected components](src/graph/connectivity.rs): 2-edge-, 2-vertex- and strongly connected components, bridges, articulation points, topological sort, 2-SAT
3941
- [Associative range query](src/arq_tree.rs): known colloquially as *segtrees*
40-
- [Math](src/math/): Euclid's GCD algorithm, Bezout's identity, rational and complex numbers, fast Fourier transform
42+
- [GCD Math](src/math/mod.rs): canonical solution to Bezout's identity
43+
- [Arithmetic](src/math/num.rs): rational and complex numbers, safe modular arithmetic
44+
- [FFT](src/math/fft.rs): fast Fourier transform, number theoretic transform, convolution
4145
- [Scanner](src/scanner.rs): utility for reading input data
4246
- [String processing](src/string_proc.rs): Knuth-Morris-Pratt string matching, suffix arrays, Manacher's palindrome search

src/math/fft.rs

Lines changed: 129 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! The Fast Fourier Transform (FFT)
2-
use super::num::{Complex, PI};
2+
use super::num::{Complex, Field, PI};
3+
use std::ops::{Add, Div, Mul, Neg, Sub};
34

45
// We can delete this struct once f64::reverse_bits() stabilizes.
56
struct BitRevIterator {
@@ -29,25 +30,95 @@ impl Iterator for BitRevIterator {
2930
}
3031
}
3132

32-
// Integer FFT notes: see problem 30-6 in CLRS for details, noting that
33-
// 15311432 and 469870224 are inverses and 2^23rd roots of 1 mod p=(119<<23)+1
34-
// 440564289 and 1713844692 are inverses and 2^27th roots of 1 mod p=(15<<27)+1
35-
// 125 and 2267742733 are inverses and 2^30th roots of 1 mod p=(3<<30)+1
33+
pub trait FFT: Sized + Copy {
34+
type F: Sized
35+
+ Copy
36+
+ From<Self>
37+
+ Neg
38+
+ Add<Output = Self::F>
39+
+ Div<Output = Self::F>
40+
+ Mul<Output = Self::F>
41+
+ Sub<Output = Self::F>;
42+
43+
const ZERO: Self;
44+
45+
/// A primitive nth root of one raised to the powers 0, 1, 2, ..., n/2 - 1
46+
fn get_roots(n: usize, inverse: bool) -> Vec<Self::F>;
47+
/// 1 for forward transform, 1/n for inverse transform
48+
fn get_factor(n: usize, inverse: bool) -> Self::F;
49+
/// The inverse of Self::F::from()
50+
fn extract(f: Self::F) -> Self;
51+
}
52+
53+
impl FFT for f64 {
54+
type F = Complex;
55+
56+
const ZERO: f64 = 0.0;
57+
58+
fn get_roots(n: usize, inverse: bool) -> Vec<Self::F> {
59+
let step = if inverse { -2.0 } else { 2.0 } * PI / n as f64;
60+
(0..n / 2)
61+
.map(|i| Complex::from_polar(1.0, step * i as f64))
62+
.collect()
63+
}
64+
65+
fn get_factor(n: usize, inverse: bool) -> Self::F {
66+
Self::F::from(if inverse { (n as f64).recip() } else { 1.0 })
67+
}
68+
69+
fn extract(f: Self::F) -> f64 {
70+
f.real
71+
}
72+
}
73+
74+
// NTT notes: see problem 30-6 in CLRS for details, keeping in mind that
75+
// 15311432 and 469870224 are inverses and 2^23rd roots of 1 mod (119<<23)+1
76+
// 440564289 and 1713844692 are inverses and 2^27th roots of 1 mod (15<<27)+1
77+
// 125 and 2267742733 are inverses and 2^30th roots of 1 mod (3<<30)+1
78+
impl FFT for u64 {
79+
type F = Field;
80+
81+
const ZERO: u64 = 0;
82+
83+
fn get_roots(n: usize, inverse: bool) -> Vec<Self::F> {
84+
assert!(n <= 1 << 23);
85+
let mut prim_root = Self::F::from(15_311_432);
86+
if inverse {
87+
prim_root = prim_root.recip();
88+
}
89+
for _ in (0..).take_while(|&i| n < 1 << (23 - i)) {
90+
prim_root = prim_root * prim_root;
91+
}
92+
93+
let mut roots = Vec::with_capacity(n / 2);
94+
let mut root = Self::F::from(1);
95+
for _ in 0..roots.capacity() {
96+
roots.push(root);
97+
root = root * prim_root;
98+
}
99+
roots
100+
}
101+
102+
fn get_factor(n: usize, inverse: bool) -> Self::F {
103+
Self::F::from(if inverse { n as u64 } else { 1 }).recip()
104+
}
105+
106+
fn extract(f: Self::F) -> u64 {
107+
f.val
108+
}
109+
}
36110

37111
/// Computes the discrete fourier transform of v, whose length is a power of 2.
38112
/// Forward transform: polynomial coefficients -> evaluate at roots of unity
39113
/// Inverse transform: values at roots of unity -> interpolated coefficients
40-
pub fn fft(v: &[Complex], inverse: bool) -> Vec<Complex> {
114+
pub fn fft<T: FFT>(v: &[T::F], inverse: bool) -> Vec<T::F> {
41115
let n = v.len();
42116
assert!(n.is_power_of_two());
43117

44-
let step = if inverse { -2.0 } else { 2.0 } * PI / n as f64;
45-
let factor = Complex::from(if inverse { n as f64 } else { 1.0 });
46-
let roots_of_unity = (0..n / 2)
47-
.map(|i| Complex::from_polar(1.0, step * i as f64))
48-
.collect::<Vec<_>>();
118+
let factor = T::get_factor(n, inverse);
119+
let roots_of_unity = T::get_roots(n, inverse);
49120
let mut dft = BitRevIterator::new(n)
50-
.map(|i| v[i] / factor)
121+
.map(|i| v[i] * factor)
51122
.collect::<Vec<_>>();
52123

53124
for m in (0..).map(|s| 1 << s).take_while(|&m| m < n) {
@@ -63,33 +134,36 @@ pub fn fft(v: &[Complex], inverse: bool) -> Vec<Complex> {
63134
dft
64135
}
65136

66-
/// From a real vector, computes a DFT of size at least desired_len
67-
pub fn dft_from_reals(v: &[f64], desired_len: usize) -> Vec<Complex> {
137+
/// From a slice of reals (f64 or u64), computes DFT of size at least desired_len
138+
pub fn dft_from_reals<T: FFT>(v: &[T], desired_len: usize) -> Vec<T::F> {
68139
assert!(v.len() <= desired_len);
140+
69141
let complex_v = v
70142
.iter()
71143
.cloned()
72-
.chain(std::iter::repeat(0.0))
144+
.chain(std::iter::repeat(T::ZERO))
73145
.take(desired_len.next_power_of_two())
74-
.map(Complex::from)
146+
.map(T::F::from)
75147
.collect::<Vec<_>>();
76-
fft(&complex_v, false)
148+
fft::<T>(&complex_v, false)
77149
}
78150

79151
/// The inverse of dft_from_reals()
80-
pub fn idft_to_reals(dft_v: &[Complex], desired_len: usize) -> Vec<f64> {
152+
pub fn idft_to_reals<T: FFT>(dft_v: &[T::F], desired_len: usize) -> Vec<T> {
81153
assert!(dft_v.len() >= desired_len);
82-
let complex_v = fft(dft_v, true);
154+
155+
let complex_v = fft::<T>(dft_v, true);
83156
complex_v
84157
.into_iter()
85158
.take(desired_len)
86-
.map(|c| c.real) // to get integers: c.real.round() as i64
159+
.map(T::extract)
87160
.collect()
88161
}
89162

90-
/// Given two polynomials (vectors) sum_i a[i]x^i and sum_i b[i]x^i,
91-
/// computes their product (convolution) c[k] = sum_(i+j=k) a[i]*b[j]
92-
pub fn convolution(a: &[f64], b: &[f64]) -> Vec<f64> {
163+
/// Given two polynomials (vectors) sum_i a[i] x^i and sum_i b[i] x^i,
164+
/// computes their product (convolution) c[k] = sum_(i+j=k) a[i]*b[j].
165+
/// Uses complex FFT if inputs are f64, or modular NTT if inputs are u64.
166+
pub fn convolution<T: FFT>(a: &[T], b: &[T]) -> Vec<T> {
93167
let len_c = a.len() + b.len() - 1;
94168
let dft_a = dft_from_reals(a, len_c).into_iter();
95169
let dft_b = dft_from_reals(b, len_c).into_iter();
@@ -102,10 +176,10 @@ mod test {
102176
use super::*;
103177

104178
#[test]
105-
fn test_dft() {
179+
fn test_complex_dft() {
106180
let v = vec![7.0, 1.0, 1.0];
107181
let dft_v = dft_from_reals(&v, v.len());
108-
let new_v = idft_to_reals(&dft_v, v.len());
182+
let new_v: Vec<f64> = idft_to_reals(&dft_v, v.len());
109183

110184
let six = Complex::from(6.0);
111185
let seven = Complex::from(7.0);
@@ -117,11 +191,40 @@ mod test {
117191
}
118192

119193
#[test]
120-
fn test_convolution() {
194+
fn test_modular_dft() {
195+
let v = vec![7, 1, 1];
196+
let dft_v = dft_from_reals(&v, v.len());
197+
let new_v: Vec<u64> = idft_to_reals(&dft_v, v.len());
198+
199+
let seven = Field::from(7);
200+
let one = Field::from(1);
201+
let prim = Field::from(15_311_432).pow(1 << 21);
202+
let prim2 = prim * prim;
203+
204+
let eval0 = seven + one + one;
205+
let eval1 = seven + prim + prim2;
206+
let eval2 = seven + prim2 + one;
207+
let eval3 = seven + prim.recip() + prim2;
208+
209+
assert_eq!(dft_v, vec![eval0, eval1, eval2, eval3]);
210+
assert_eq!(new_v, v);
211+
}
212+
213+
#[test]
214+
fn test_complex_convolution() {
121215
let x = vec![2.0, 3.0, 2.0];
122216
let y = vec![7.0, 2.0];
123217
let z = convolution(&x, &y);
124218

125219
assert_eq!(z, vec![14.0, 25.0, 20.0, 4.0]);
126220
}
221+
222+
#[test]
223+
fn test_modular_convolution() {
224+
let x = vec![2, 3, 2];
225+
let y = vec![7, 2];
226+
let z = convolution(&x, &y);
227+
228+
assert_eq!(z, vec![14, 25, 20, 4]);
229+
}
127230
}

src/math/mod.rs

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,13 @@
22
pub mod fft;
33
pub mod num;
44

5-
/// Modular exponentiation by repeated squaring: returns base^exp % m.
6-
///
7-
/// # Panics
8-
///
9-
/// Panics if m == 0. May panic on overflow if m * m > 2^63.
10-
pub fn mod_pow(mut base: u64, mut exp: u64, m: u64) -> u64 {
11-
let mut result = 1 % m;
12-
while exp > 0 {
13-
if exp % 2 == 1 {
14-
result = (result * base) % m;
15-
}
16-
base = (base * base) % m;
17-
exp /= 2;
18-
}
19-
result
20-
}
21-
225
/// Finds (d, coef_a, coef_b) such that d = gcd(a, b) = a * coef_a + b * coef_b.
236
pub fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
247
if b == 0 {
258
(a.abs(), a.signum(), 0)
269
} else {
27-
let (d, coef_a, coef_b) = extended_gcd(b, a % b);
28-
(d, coef_b, coef_a - coef_b * (a / b))
10+
let (d, coef_b, coef_a) = extended_gcd(b, a % b);
11+
(d, coef_a, coef_b - coef_a * (a / b))
2912
}
3013
}
3114

@@ -50,17 +33,6 @@ pub fn canon_egcd(a: i64, b: i64, c: i64) -> Option<(i64, i64, i64)> {
5033
mod test {
5134
use super::*;
5235

53-
#[test]
54-
fn test_mod_inverse() {
55-
let p = 1_000_000_007;
56-
let base = 31;
57-
58-
let base_inv = mod_pow(base, p - 2, p);
59-
let identity = (base * base_inv) % p;
60-
61-
assert_eq!(identity, 1);
62-
}
63-
6436
#[test]
6537
fn test_egcd() {
6638
let (a, b) = (14, 35);

0 commit comments

Comments
 (0)