Cipolla's Algorithm
· 2 min · number theory , modular arithmetic , competitive programming
Recently I was solving this problem. It took me a while, and I encountered many issues along the way, but I found parts of the solution really interesting and rewarding.
Since there is no English statement, I will provide a summary here.
You are given a polynomial $F(x)$ of degree $n$. You are asked to evaluate a complex expression in $F(x)$ as a formal power series and output the coefficients of its first $n$ terms modulo $998244353$.
Turns out, many of the operations required are already well documented online (see here), and all you need to do is implement them. However, I was unable to find any comprehensive competitive-programming-focused resources on the square root operation, which I will attempt to explain in this post.
Finding a Recurrence Relation#
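Suppose we have a polynomial $F(x)$, and we want to find a polynomial $G(x)$ such that $G(x)^2 \equiv F(x) \pmod{x^n}$. We can define a function:
$$\Phi(G) = G^2 - F$$
From here we can use Newton’s Method to get a recurrence relation that turns the first $m$ terms of this power series, $G_m \equiv G \pmod{x^m}$, into the first $2m$ terms:
$$G_{2m} \equiv G_m - \frac{\Phi(G_m)}{\Phi'(G_m)} = G_m - \frac{G_m^2 - F}{2 G_m} = \frac{G_m^2 + F}{2 G_m} \pmod{x^{2m}}$$
We now need to find a suitable $G_1$, i.e. the constant term of $G$. This is trivial for polynomials with a constant term that is a perfect square. However, if the constant term is not a perfect square, we need to find solutions to the congruence
$$G_1^2 \equiv F(0) \pmod{998244353}$$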
Cipolla’s Algorithm#
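Cipolla’s Algorithm is a technique used to solve congruences of the form:
$$x^2 \equiv n \pmod{p}$$
where $p$ is an odd prime.
To check if an integer $n$ is a quadratic residue modulo $p$, we can compute its Legendre Symbol and check if it is equal to $1$. Since $p$ is an odd prime, this can be found with:
$$\left(\frac{n}{p}\right) \equiv n^{\frac{p-1}{2}} \pmod{p}$$
We start off by finding an $a$ such that $a^2 - n$ is not a quadratic residue modulo $p$. Since there is no known efficient deterministic algorithm to find such an $a$, we can sample random $a$ until a suitable one is found. The chance that a random $a$ is suitable is $\frac{p-1}{2p} \approx \frac{1}{2}$ for a large prime $p$, making the expected number of guesses ~$2$.
After that, we can compute
$$x \equiv \left(a + \sqrt{a^2 - n}\right)^{\frac{p+1}{2}} \pmod{p}$$
where $\sqrt{a^2 - n}$ is analogous to the imaginary unit of a complex number in an extended field: it has no value modulo $p$, so we adjoin it as a new element $\omega$ with $\omega^2 = a^2 - n$ and compute with pairs $u + v\omega$.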
Implementation#
Here is an implementation of the solution to the original problem in C++:
/* https://cp-algorithms.com/algebra/polynomial.html#calculating-functions-of-polynomial */
#include <bits/stdc++.h>
using namespace std;
// Shorthand for the 64-bit signed integer type used for all modular arithmetic below.
#define ll long long
// A polynomial / truncated power series as a coefficient vector: index i holds the x^i coefficient.
using poly = vector<ll>;
// A pair {c0, c1} read as c0 + c1 * w, where w * w == f2; multiplied by mul(f2p, f2p).
using f2p = array<ll, 2>;
// Prime modulus of the form c * 2 ^ k + 1 (998244353 = 119 * 2^23 + 1).
const int mod = 998244353; // c * 2 ^ k + 1
// Starting root of unity for ntt(); squared there until it matches the current butterfly length.
const int root = 15311432; // PrimitiveRoot[mod] ^ c
// Modular inverse of `root`, used by ntt() for the inverse transform.
const int root_1 = 469870224; // root ^ -1 % mod
// Multiplicative order of `root`; per ntt()'s squaring loop, also the largest transform length it supports.
const int root_pw = 1 << 23; // 2 ^ k
// Square of the adjoined element w used by mul(f2p, f2p). Overwritten by cipollas(); -1 until then.
ll f2 = -1; // a^2 - n
// Multiply two residues and reduce the product with `% mod`.
// The result carries the sign of the raw product, exactly as `%` yields it.
ll mul(ll a, ll b) {
    const ll product = a * b;
    return product % mod;
}
// Multiply two extension elements a0 + a1*w and b0 + b1*w, where w*w == f2 (global):
//   (a0 + a1 w)(b0 + b1 w) = (a0 b0 + a1 b1 f2) + (a0 b1 + a1 b0) w
// Every partial product is reduced with `% mod` before the next multiplication.
f2p mul(f2p a, f2p b) {
    const ll plain_part = a[0] * b[0] % mod;
    const ll folded_part = a[1] * b[1] % mod * f2 % mod;  // the w*w term, folded back via f2
    const ll cross_ab = a[0] * b[1] % mod;
    const ll cross_ba = a[1] * b[0] % mod;

    f2p product = {
        (plain_part + folded_part) % mod,
        (cross_ab + cross_ba) % mod
    };
    return product;
}
template <typename T>
T bexp(T x, ll n, T identity) {
assert(n >= 0);
T res = identity; // identity element
while(n > 0) {
if(n & 1) res = mul(res, x);
x = mul(x, x);
n >>= 1;
}
return res;
}
// Modular inverse of x modulo the prime `mod`, computed as x^(mod - 2)
// (Fermat's little theorem).
//
// x is first reduced into [0, mod), so negative or unreduced arguments give the
// canonical inverse and the first squaring inside bexp cannot overflow.
// An x congruent to 0 has no inverse; the function then returns 0.
ll inverse(ll x) {
    x %= mod;
    if (x < 0) x += mod;
    return bexp(x, mod - 2, 1LL);
}
// In-place number-theoretic transform of `a` modulo `mod`.
// invert == false: forward transform using `root`.
// invert == true:  uses `root_1` instead and finally multiplies every entry by inverse(n).
// NOTE(review): nothing here checks the preconditions. The code appears to assume that
// a.size() is a power of two not exceeding root_pw and that every entry is already in
// [0, mod) -- the single conditional add/subtract in the butterfly is only enough in that case.
void ntt(poly &a, bool invert) {
int n = a.size();
// Reorder the entries: j is kept as the bit-reversed counterpart of i, and each pair is swapped once.
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
for (; j & bit; bit >>= 1)
j ^= bit;
j ^= bit;
if (i < j)
swap(a[i], a[j]);
}
// Butterfly passes over blocks of length 2, 4, ..., n.
for (int len = 2; len <= n; len <<= 1) {
// Square the starting root once per doubling between len and root_pw to get the step for this block length.
int wlen = invert ? root_1 : root;
for (int i = len; i < root_pw; i <<= 1)
wlen = (int)(1LL * wlen * wlen % mod);
for (int i = 0; i < n; i += len) {
// w runs through successive powers of wlen within the block.
int w = 1;
for (int j = 0; j < len / 2; j++) {
// u: entry in the lower half; v: matching entry in the upper half times the current power w.
int u = a[i + j], v = (int)(1LL * a[i + j + len / 2] * w % mod);
// u + v and u - v are brought back into range with one conditional subtract/add of mod.
a[i + j] = u + v < mod ? u + v : u + v - mod;
a[i + j + len / 2] = u - v >= 0 ? u - v : u - v + mod;
w = (int)(1LL * w * wlen % mod);
}
}
}
// The inverse transform additionally scales every entry by n^-1 modulo mod.
if (invert) {
int n_1 = inverse(n);
for (ll &x : a)
x = (int)(1LL * x * n_1 % mod);
}
}
// Product of two polynomials modulo `mod`, truncated (or zero-padded) to exactly N coefficients.
//
// a, b : coefficient vectors (index i = coefficient of x^i); entries may be negative.
// N    : number of coefficients in the returned vector.
//
// When `a` has exactly one coefficient the product is a plain scalar multiple of `b`,
// so no transform is needed. Otherwise the product is computed with ntt().
poly mul(poly a, poly b, int N) {
    if (a.size() == 1) {
        for (ll &coef : b) {
            coef = ((coef * a[0]) % mod + mod) % mod;
        }
        b.resize(N);
        return b;
    }

    // Terms of degree >= N cannot contribute to the first N coefficients of the
    // product, so dropping them up front keeps the transform as small as possible.
    if ((int) a.size() > N) a.resize(N);
    if ((int) b.size() > N) b.resize(N);

    // ntt() only performs a single conditional add/subtract per butterfly, which
    // is correct only for inputs in [0, mod). Normalise here, as the scalar path does.
    for (ll &coef : a) coef = (coef % mod + mod) % mod;
    for (ll &coef : b) coef = (coef % mod + mod) % mod;

    int n = 1;
    while (n < (int) (a.size() + b.size()))
        n <<= 1;
    // `root` only provides roots of unity for lengths up to root_pw; a longer
    // transform would silently return wrong coefficients.
    assert(n <= root_pw);

    a.resize(n);
    b.resize(n);
    ntt(a, false);
    ntt(b, false);
    for (int i = 0; i < n; i++)
        a[i] = (a[i] * b[i]) % mod;
    ntt(a, true);

    a.resize(N);
    return a;
}
// a + b
// Coefficient-wise sum a + b modulo `mod`, truncated (or zero-padded) to exactly N coefficients.
//
// Every coefficient of the result is normalised into [0, mod), including when an
// operand holds negative or unreduced values (for example the `{2 - p[0]}` constant
// built in solve()), so later consumers such as ntt() always see canonical residues.
poly add(poly a, poly b, int N) {
    // Keep the longer operand in `a` so the shorter one can be added into its prefix.
    if (a.size() < b.size()) swap(a, b);
    for (size_t i = 0; i < b.size(); i++) {
        a[i] += b[i];
    }
    for (ll &coef : a) {
        coef %= mod;
        if (coef < 0) coef += mod;
    }
    a.resize(N);
    return a;
}
// integral of p
// Formal antiderivative of p with constant term 0, resized to N coefficients:
// the coefficient of x^(k+1) is p[k] * (k+1)^-1 modulo `mod`.
poly integral(poly p, int N) {
    const int terms = p.size();
    poly result(terms + 1, 0);  // slot 0 stays 0: the constant of integration
    for (int k = 0; k < terms; k++) {
        result[k + 1] = (inverse(k + 1) * p[k]) % mod;
    }
    result.resize(N);
    return result;
}
// p'
// Formal derivative of p, resized to N coefficients:
// the coefficient of x^k is p[k + 1] * (k + 1) modulo `mod`.
poly derivative(poly p, int N) {
    const int len = p.size();
    poly out;
    for (int k = 0; k + 1 < len; k++) {
        out.push_back((p[k + 1] * (k + 1)) % mod);
    }
    out.resize(N);
    return out;
}
// a^-1
// Power-series inverse of a, returned with N coefficients.
// Starts from the inverse of the constant term and doubles the working length each
// round with the update q <- q * (2 - a * q), using only the first `len`
// coefficients of `a` that are needed at that precision.
poly inverse(poly a, int N) {
    poly approx = {inverse(a[0])};
    poly prefix;  // leading coefficients of `a` consumed so far
    int len = 1;
    do {
        len <<= 1;
        const int want = min(len, (int) a.size());
        while ((int) prefix.size() < want) {
            prefix.push_back(a[prefix.size()]);
        }
        poly a_times_q = mul(prefix, approx, len);
        poly negated = mul({-1}, a_times_q, len);
        poly correction = add({2}, negated, len);
        approx = mul(approx, correction, len);
    } while (len < N);
    approx.resize(N);
    return approx;
}
// ln a
// Power-series logarithm of a with N coefficients, computed as the integral of a' / a.
poly logarithm(poly a, int N) {
    poly slope = derivative(a, N);
    poly reciprocal = inverse(a, N);
    poly quotient = mul(slope, reciprocal, N);
    return integral(quotient, N);
}
// e^p
// Power-series exponential of p, returned with N coefficients.
// Starts from the constant series 1 and doubles the working length each round with
// the update q <- q * (1 + p - ln q), using only the first `len` coefficients of `p`.
poly exp(poly p, int N) {
    poly series = {1};
    poly head;  // leading coefficients of `p` consumed so far
    int len = 1;
    do {
        len <<= 1;
        const int want = min(len, (int) p.size());
        for (int at = head.size(); at < want; at++) {
            head.push_back(p[at]);
        }
        poly neg_log = mul({-1}, logarithm(series, len), len);
        poly residual = add(head, neg_log, len);
        poly factor = add({1}, residual, len);
        series = mul(series, factor, len);
    } while (len < N);
    series.resize(N);
    return series;
}
// raise p^k
// p raised to the k-th power as a power series with N coefficients, computed as
// exp(k * ln p).
// NOTE(review): this identity presumes the constant term of p is 1; nothing here checks it.
poly pow(poly p, int k, int N) {
    // Take the logarithm once and reuse it (it was previously computed twice, with
    // the first result discarded).
    poly lnp = logarithm(p, N);
    return exp(mul({k}, lnp, N), N);
}
// algorithms described in the blog
// Cipolla's algorithm: find x with x * x == n (mod `mod`).
//
// n       : the value whose square root is wanted; reduced into [0, mod) first, so
//           negative or unreduced arguments are handled.
// returns : the smaller of the two roots x and mod - x, or -1 when n is not a
//           non-zero quadratic residue. n == 0 (mod `mod`) also yields -1: sqrt()
//           divides by twice the root and needs it to be invertible.
//
// Side effects: overwrites the global f2 (read by mul(f2p, f2p)) and draws from rand().
ll cipollas(ll n) {
    n %= mod;
    if (n < 0) n += mod;

    // Euler's criterion: n^((mod - 1) / 2) is 1 exactly for non-zero quadratic residues.
    if (bexp(n, mod >> 1, 1LL) != 1) return -1;

    ll a = -1;
    while (true) {
        // Roughly half of all candidates qualify, so this takes ~2 rounds on average.
        a = rand() % mod;
        f2 = (a * a % mod - n + mod) % mod;
        // Stop once a^2 - n is not a non-zero residue. If it is 0 then a is already
        // a root, and the power below still evaluates to +a or -a.
        if (bexp(f2, mod >> 1, 1LL) != 1) break;
    }

    // (a + w)^((mod + 1) / 2) with w * w == f2; its w-free component is the root.
    f2p cn = {a, 1};
    f2p cnp = bexp(cn, (mod + 1) >> 1, {1, 0});
    ll ans = (cnp[0] + mod) % mod;

    // Final sanity check of the candidate before handing it back.
    if (ans * ans % mod == n) return min(ans, mod - ans);
    else return -1;
}
// Power-series square root of p, returned with N coefficients.
// The constant term comes from cipollas(p[0]) and must exist (asserted); after that
// the working length doubles each round with the update q <- (q * q + p) / (2 * q).
poly sqrt(poly p, int N) {
    const ll root0 = cipollas(p[0]);
    assert(root0 != -1);
    poly q = {root0};
    for (int len = 2; ; len <<= 1) {
        poly numerator = add(mul(q, q, len), p, len);
        poly twice_q = mul({2}, q, len);
        poly denominator_inv = inverse(twice_q, len);
        q = mul(denominator_inv, numerator, len);
        if (len >= N) break;
    }
    q.resize(N);
    return q;
}
void solve() {
int n;
ll k;
cin >> n >> k;
poly p(n + 1);
int N = n + 1;
for (int i = 0; i <= n; i++) {
cin >> p[i];
}
poly res1 = sqrt(p, N);
poly res2 = inverse(res1, N);
poly res3 = integral(res2, N);
poly res4 = exp(res3, N);
poly res5 = add({2 - p[0]}, add(p, mul({-1}, res4, N), N), N);
poly res6 = logarithm(res5, N);
poly res7 = add({1}, res6, N);
poly res8 = pow(res7, k, N);
poly res9 = derivative(res8, N);
res9.resize(n);
for (int i = 0; i < n; i++) {
cout << res9[i] << ' ';
}
cout << '\n';
}
// Entry point: switches iostream to fast, untied mode and runs solve() once.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    // freopen("io.out", "w", stdout);

    int T = 1;  // number of test cases; fixed at one unless the read below is enabled
    // cin >> T;
    for (int done = 0; done < T; ++done) {
        solve();
    }
    return 0;
}