Andy Chen

← All posts

Cipolla's Algorithm

· 2 min · number theory , modular arithmetic , competitive programming


Recently I was solving this problem. It took me a while, and I encountered many issues along the way, but I found parts of the solution really interesting and rewarding.

Since there is no English statement, I will provide a summary here.

You are given a polynomial $F$ with degree $n$. You are asked to evaluate a complex expression as a formal power series and output the coefficients of the first $n$ terms modulo $998244353$.

Turns out, many of the operations required are already well documented online (see here), and all you need to do is implement them. However, I was unable to find any comprehensive competitive-programming-focused resources on the square root operation, which I will attempt to explain in this post.

Finding a Recurrence Relation#

Suppose we have a polynomial $P$, and we want to find a polynomial $Q$ such that $Q^2 = P$. We can define a function:

$$F(Q) = Q^2 - P$$

$$F'(Q) = 2Q$$

From here we can use Newton’s Method, $Q_{k+1} = Q_k - \frac{F(Q_k)}{F'(Q_k)}$, to get a recurrence relation in which $Q_k$ is correct for the first $2^k$ terms of this power series:

$$\boxed{Q_{k + 1} \equiv \frac{{Q_k}^2 + P}{2{Q_k}} \pmod{x^{2^{k+1}}}}$$

We now need to find a suitable $Q_0$. This is trivial for polynomials with a constant term that is a perfect square. However, if the constant term is not a perfect square, we need to find solutions to the congruence $x^2 \equiv a \pmod{p}$.

Cipolla’s Algorithm#

Cipolla’s Algorithm is a technique used to solve congruences of the form:

$$x^2 \equiv n \pmod{p}$$

where $p$ is an odd prime.

To check if an integer $n$ is a quadratic residue modulo $p$, we can compute its Legendre Symbol and check if it is equal to $1$. Since $p$ is an odd prime, this can be found with:

$$n^{(p-1)/2}\ \mathrm{mod}\ p$$

We start off by finding an $a$ such that $a^2 - n$ is not a quadratic residue modulo $p$. Since there is no known efficient deterministic algorithm to find such an $a$, we can sample random values of $a$ until a suitable one is found. The chance that a random $a$ is suitable is $\frac{p-1}{2p} \approx \frac{1}{2}$ for a large prime $p$, making the expected number of guesses about $2$.

After that, we can compute

$$\boxed{x = \left(a + \sqrt{a^2 - n} \right)^{(p + 1)/2}\ \mathrm{mod}\ p}$$

where $\sqrt{a^2 - n}$ is analogous to the imaginary unit of the complex numbers: it does not exist modulo $p$, so it is adjoined as a new element of an extended field.

Implementation#

Here is an implementation of the solution to the original problem in C++:

/* https://cp-algorithms.com/algebra/polynomial.html#calculating-functions-of-polynomial */
#include <bits/stdc++.h>
using namespace std;
// Shorthand for the 64-bit signed integer type used for all modular arithmetic.
#define ll long long

// Coefficient list of a polynomial / truncated power series: index i holds the
// coefficient of x^i.
using poly = vector<ll>;
// Element of the quadratic extension used by Cipolla's algorithm: {u, v} stands
// for u + v * sqrt(f2), with f2 being the global declared below.
using f2p = array<ll, 2>;

const int mod = 998244353; // NTT-friendly prime of the form c * 2 ^ k + 1 (here 119 * 2^23 + 1)
const int root = 15311432; // PrimitiveRoot[mod] ^ c, used by ntt() as the base root of unity of order root_pw
const int root_1 = 469870224; // root ^ -1 % mod, used by ntt() for the inverse transform
const int root_pw = 1 << 23; // 2 ^ k; ntt() derives the root for each butterfly length from this order
// a^2 - n: the value whose square root generates the extension field.
// Written by cipollas() and read by mul(f2p, f2p); -1 until cipollas() sets it.
ll f2 = -1; // a^2 - n

// Product of two scalars reduced modulo `mod`.
// The result carries the sign of a * b (C++ remainder semantics), so callers
// that need a value in [0, mod) must pass non-negative operands.
ll mul(ll a, ll b) {
    const ll product = a * b;
    return product % mod;
}

// Product of two extension-field elements u + v*w, where w*w equals the global
// `f2`:  (a0 + a1 w)(b0 + b1 w) = (a0 b0 + a1 b1 f2) + (a0 b1 + a1 b0) w.
// Each partial product is reduced modulo `mod` before being combined.
f2p mul(f2p a, f2p b) {
    const ll base_part = (a[0] * b[0] % mod + a[1] * b[1] % mod * f2 % mod) % mod;
    const ll ext_part = (a[0] * b[1] % mod + a[1] * b[0] % mod) % mod;
    return {base_part, ext_part};
}

// Binary (square-and-multiply) exponentiation: returns x^n using whichever
// `mul` overload matches T.
//   x        - base
//   n        - exponent, must be non-negative (asserted)
//   identity - multiplicative identity of T, returned unchanged when n == 0
template <typename T>
T bexp(T x, ll n, T identity) {
    assert(n >= 0);
    T acc = identity;
    // Consume the exponent bit by bit, lowest bit first; x is squared each round.
    for (; n > 0; n >>= 1) {
        if (n & 1) acc = mul(acc, x);
        x = mul(x, x);
    }
    return acc;
}

// Modular inverse of x, computed as x^(mod - 2) (Fermat's little theorem,
// valid because `mod` is prime). An x that is 0 modulo `mod` yields 0.
ll inverse(ll x) {
    const ll fermat_exponent = mod - 2;
    return bexp(x, fermat_exponent, 1LL);
}

// In-place number-theoretic transform of `a` modulo `mod`.
//   a      - coefficients, overwritten with the transform. The loop structure
//            presumes a.size() is a power of two no larger than root_pw; this
//            is not checked here.
//   invert - false: forward transform using `root`;
//            true: inverse transform using `root_1`, followed by a division
//            of every entry by a.size().
// NOTE(review): entries are narrowed to int inside the butterflies, so they are
// presumably expected to lie in [0, mod) on entry — confirm at call sites.
void ntt(poly &a, bool invert) {
    int n = a.size();

    // Reorder the entries: j is maintained as the bit-reversed counterpart of i,
    // and each pair (i, j) is swapped once (only when i < j).
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;

        // Add one to j in bit-reversed order: clear leading set bits...
        for (; j & bit; bit >>= 1)
            j ^= bit;

        // ...then set the first clear bit.
        j ^= bit;

        if (i < j)
            swap(a[i], a[j]);
    }

    // Butterfly passes over blocks of length 2, 4, ..., n.
    for (int len = 2; len <= n; len <<= 1) {
        int wlen = invert ? root_1 : root;

        // Square the order-root_pw root once per doubling between len and
        // root_pw, which leaves a root of unity of order len.
        for (int i = len; i < root_pw; i <<= 1)
            wlen = (int)(1LL * wlen * wlen % mod);

        for (int i = 0; i < n; i += len) {
            int w = 1; // current power of wlen within this block

            for (int j = 0; j < len / 2; j++) {
                // u: entry from the first half; v: matching entry from the second half times w.
                int u = a[i + j], v = (int)(1LL * a[i + j + len / 2] * w % mod);
                // Store (u + v) and (u - v), each with one conditional correction by mod.
                a[i + j] = u + v < mod ? u + v : u + v - mod;
                a[i + j + len / 2] = u - v >= 0 ? u - v : u - v + mod;
                w = (int)(1LL * w * wlen % mod);
            }
        }
    }

    if (invert) {
        // The inverse transform additionally scales every entry by n^-1 mod `mod`.
        int n_1 = inverse(n);

        for (ll &x : a)
            x = (int)(1LL * x * n_1 % mod);
    }
}

// Product of two polynomials, truncated to its first N coefficients
// (i.e. a * b mod x^N), with every returned coefficient in [0, mod).
//   a, b - factors; coefficients may be any residues, including negative ones
//   N    - number of coefficients to return (shorter results are zero-padded)
// A factor with a single coefficient is handled by direct scaling; everything
// else goes through ntt().
poly mul(poly a, poly b, int N) {
    const size_t keep = static_cast<size_t>(N);

    // Terms of degree >= N cannot influence the first N coefficients of the
    // product, so drop them up front to keep the transform as small as possible.
    if (a.size() > keep) a.resize(keep);
    if (b.size() > keep) b.resize(keep);

    // An empty factor is the zero polynomial.
    if (a.empty() || b.empty()) return poly(keep, 0);

    // Bring every coefficient into [0, mod): ntt() narrows entries to int and
    // corrects by at most one multiple of mod, so it needs canonical residues.
    auto normalize = [](poly &v) {
        for (ll &c : v) {
            c %= mod;
            if (c < 0) c += mod;
        }
    };
    normalize(a);
    normalize(b);

    // Scalar shortcut, applied whichever side holds the constant.
    if (a.size() != 1 && b.size() == 1) swap(a, b);
    if (a.size() == 1) {
        for (ll &c : b) c = c * a[0] % mod;
        b.resize(keep);
        return b;
    }

    // Transform length: smallest power of two that is >= a.size() + b.size().
    size_t n = 1;
    while (n < a.size() + b.size())
        n <<= 1;

    a.resize(n);
    b.resize(n);

    ntt(a, false);
    ntt(b, false);

    // Pointwise product in the transformed domain.
    for (size_t i = 0; i < n; i++)
        a[i] = a[i] * b[i] % mod;

    ntt(a, true);

    a.resize(keep);
    return a;
}

// a + b, truncated or zero-padded to N coefficients.
// Every returned coefficient is reduced into [0, mod), even when an operand
// carries negative entries (callers pass values such as {2 - p[0]}).
poly add(poly a, poly b, int N) {
    // Accumulate into the longer operand so no index falls outside it.
    if (a.size() < b.size()) swap(a, b);
    for (size_t i = 0; i < b.size(); i++) {
        a[i] += b[i];
    }

    a.resize(N);

    // C++ '%' keeps the sign of the dividend, so fold negative remainders back
    // into range; this also covers the tail of `a` that `b` never touched.
    for (ll &c : a) {
        c %= mod;
        if (c < 0) c += mod;
    }
    return a;
}

// Formal integral of p with constant term 0, truncated or zero-padded to N
// coefficients: the coefficient of x^(i+1) is p[i] / (i + 1) modulo `mod`.
poly integral(poly p, int N) {
    poly r;
    r.push_back(0); // integration constant

    int new_degree = 0;
    for (ll coef : p) {
        ++new_degree; // degree this term has after integration
        r.push_back(inverse(new_degree) * coef % mod);
    }

    r.resize(N);
    return r;
}

// Formal derivative p', truncated or zero-padded to N coefficients:
// the coefficient of x^(i-1) is i * p[i] modulo `mod`.
poly derivative(poly p, int N) {
    poly r;

    int degree = 0;
    for (ll coef : p) {
        // The constant term vanishes; every other term drops one degree.
        if (degree > 0) r.push_back(coef * degree % mod);
        ++degree;
    }

    r.resize(N);
    return r;
}

// Multiplicative inverse a^-1 as a power series, returned with N coefficients.
// Starts from the scalar inverse of a[0] and applies the Newton step
//   q <- q * (2 - a * q)   (mod x^sz)
// with sz doubling each round; at least one round is always performed.
// `a` must be non-empty (a[0] is read unconditionally).
poly inverse(poly a, int N) {
    poly q({inverse(a[0])});
    poly short_a;

    for (int sz = 2;; sz <<= 1) {
        // Only the first min(sz, |a|) coefficients of a matter at this precision.
        const int prefix_len = min(sz, (int) a.size());
        short_a.assign(a.begin(), a.begin() + prefix_len);

        poly aq = mul(short_a, q, sz);
        poly neg_aq = mul({-1}, aq, sz);
        poly correction = add({2}, neg_aq, sz);
        q = mul(q, correction, sz);

        if (sz >= N) break;
    }

    q.resize(N);
    return q;
}

// ln a as a power series with N coefficients, computed as the integral of
// a' / a. The constant term of the result is always 0 (see integral()).
poly logarithm(poly a, int N) {
    poly a_prime = derivative(a, N);
    poly a_inv = inverse(a, N);
    return integral(mul(a_prime, a_inv, N), N);
}

// e^p as a power series with N coefficients.
// Starts from the constant 1 and applies the Newton step
//   q <- q * (1 + p - ln q)   (mod x^sz)
// with sz doubling each round; at least one round is always performed.
poly exp(poly p, int N) {
    poly q({1});
    poly short_p;

    for (int sz = 2;; sz <<= 1) {
        // Only the first min(sz, |p|) coefficients of p matter at this precision.
        const int prefix_len = min(sz, (int) p.size());
        short_p.assign(p.begin(), p.begin() + prefix_len);

        poly neg_ln_q = mul({-1}, logarithm(q, sz), sz);
        poly residual = add(short_p, neg_ln_q, sz);
        poly factor = add({1}, residual, sz);
        q = mul(q, factor, sz);

        if (sz >= N) break;
    }

    q.resize(N);
    return q;
}

// p^k as a power series with N coefficients, computed as exp(k * ln p).
//   p - base; NOTE(review): the ln/exp route presumably requires p[0] == 1,
//       since logarithm() always produces a zero constant term — confirm at
//       call sites.
//   k - exponent. Taken as a 64-bit value so that callers holding a long long
//       are not silently truncated, and reduced modulo `mod` before use (the
//       exponent only enters as a scalar multiplier of coefficients mod `mod`).
//   N - number of coefficients to return.
poly pow(poly p, ll k, int N) {
    // Canonical residue in [0, mod); also keeps the scalar product below from
    // overflowing for large |k|.
    const ll scalar = (k % mod + mod) % mod;

    // Compute the logarithm once and reuse it.
    poly lnp = logarithm(p, N);
    return exp(mul({scalar}, lnp, N), N);
}

// algorithms described in the blog

// Cipolla's algorithm: a square root of n modulo the prime `mod`.
//   n - any integer; it is first reduced into [0, mod).
// Returns the smaller of the two roots x and mod - x, or -1 when no root is
// produced. That covers every non-residue, and also n == 0 (mod `mod`), because
// Euler's criterion evaluates to 0 rather than 1 there.
// Side effect: overwrites the global `f2`, which mul(f2p, f2p) reads.
ll cipollas(ll n) {
    // Canonical residue, so the comparisons and subtractions below are valid
    // for negative or >= mod inputs as well.
    n = (n % mod + mod) % mod;

    // Euler's criterion: n^((mod-1)/2) == 1 exactly for non-zero quadratic residues.
    if (bexp(n, mod >> 1, 1LL) != 1) return -1;

    // Sample a until a^2 - n fails Euler's criterion; about half of all
    // candidates qualify, so this takes ~2 attempts on average.
    ll a = -1;
    while (true) {
        a = rand() % mod;
        f2 = (a * a % mod - n + mod) % mod;
        if (bexp(f2, mod >> 1, 1LL) != 1) break;
    }

    // (a + sqrt(f2))^((mod+1)/2), evaluated in the extension field; the
    // component without sqrt(f2) is the candidate root.
    f2p cn = {a, 1};
    f2p cnp = bexp(cn, (mod + 1) >> 1, {1, 0});
    ll ans = (cnp[0] + mod) % mod;

    // Verify the candidate before handing it out.
    if (ans * ans % mod == n) return min(ans, mod - ans);
    return -1;
}

// Square root of p as a power series with N coefficients.
// The constant term comes from cipollas(p[0]) (asserted to exist); the rest is
// obtained with the Newton step
//   q <- (q^2 + p) / (2 q)   (mod x^sz)
// with sz doubling each round; at least one round is always performed.
// `p` must be non-empty (p[0] is read unconditionally).
poly sqrt(poly p, int N) {
    const ll root0 = cipollas(p[0]);
    assert(root0 != -1);

    poly q({root0});
    for (int sz = 2;; sz <<= 1) {
        poly two_q = mul({2}, q, sz);
        poly den = inverse(two_q, sz);
        poly q_squared = mul(q, q, sz);
        poly num = add(q_squared, p, sz);
        q = mul(den, num, sz);

        if (sz >= N) break;
    }

    q.resize(N);
    return q;
}

void solve() {
    int n;
    ll k;
    cin >> n >> k;
    poly p(n + 1);
    int N = n + 1;
    for (int i = 0; i <= n; i++) {
        cin >> p[i];
    }

    poly res1 = sqrt(p, N);
    poly res2 = inverse(res1, N);
    poly res3 = integral(res2, N);
    poly res4 = exp(res3, N);
    poly res5 = add({2 - p[0]}, add(p, mul({-1}, res4, N), N), N);
    poly res6 = logarithm(res5, N);
    poly res7 = add({1}, res6, N);
    poly res8 = pow(res7, k, N);
    poly res9 = derivative(res8, N);
    
    res9.resize(n);
    for (int i = 0; i < n; i++) {
        cout << res9[i] << ' ';
    }
    cout << '\n';
}

// Entry point: switches to fast stream I/O and runs solve() once per test case
// (a single case unless the commented-out read of T is enabled).
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    // freopen("io.out", "w", stdout);

    int T = 1;
    // cin >> T;

    for (int test_case = 0; test_case < T; ++test_case) {
        solve();
    }
    return 0;
}