This documentation is automatically generated by online-judge-tools/verification-helper
View the Project on GitHub kmyk/competitive-programming-library
#include "modulus/modsqrt.hpp"
#pragma once
#include <cassert>
#include <cstdint>
#include "../modulus/modpow.hpp"
/**
 * @brief find the sqrt $b$ of $a$ modulo $p$ (Tonelli-Shanks)
 * @param a any int; it is first reduced into $[0, p)$, so negative values are accepted
 * @param p must be a prime
 * @return some $b \in [0, p)$ with $b^2 \equiv a \pmod{p}$
 * @note -1 if not found (i.e. $a$ is a quadratic non-residue modulo $p$)
 */
int modsqrt(int a, int p) {
    // Euler's criterion: x^((p-1)/2) mod p is 1 for a non-zero quadratic residue, p - 1 otherwise
    auto legendre_symbol = [&](int x) {
        return modpow(x, (p - 1) / 2, p);
    };
    a %= p;
    if (a < 0) a += p;  // C++ % keeps the dividend's sign; modpow requires 0 <= x < p
    if (a == 0) return 0;
    if (p == 2) return a;
    assert (p >= 3);
    if (legendre_symbol(a) != +1) return -1;

    // find a quadratic non-residue b (one exists below p since p is an odd prime)
    int b = 1;
    while (legendre_symbol(b) == 1) {
        b += 1;
    }

    // write p - 1 = m * 2^e with m odd
    int e = 0;
    int m = p - 1;
    while (m % 2 == 0) {
        m /= 2;
        e += 1;
    }

    // invariant: x^2 == a * y (mod p); the loop drives y to 1, at which point x^2 == a
    int64_t x = modpow(a, (m - 1) / 2, p);  // a^((m-1)/2)
    int64_t y = a * x % p * x % p;          // a^m
    x = x * a % p;                          // a^((m+1)/2)
    int64_t z = modpow(b, m, p);            // b^m, has multiplicative order exactly 2^e
    while (y != 1) {
        // j = least exponent with y^(2^j) == 1
        int j = 0;
        for (int64_t t = y; t != 1; t = t * t % p) ++ j;
        if (e <= j) return -1;  // cannot happen for a residue; kept as a safety net
        z = modpow(z, 1ll << (e - j - 1), p);
        x = x * z % p;
        z = z * z % p;
        y = y * z % p;
        e = j;
    }
    assert (x * x % p == a);
    return static_cast<int>(x);
}
#line 2 "modulus/modsqrt.hpp"
#include <cassert>
#include <cstdint>
#line 4 "modulus/modpow.hpp"
/**
 * @brief compute $x^k \bmod \mathrm{MOD}$ by binary exponentiation
 * @param x base; must satisfy $0 \le x < \mathrm{MOD}$
 * @param k exponent (any non-negative value)
 * @param MOD modulus; must be positive
 * @return $x^k \bmod \mathrm{MOD}$, always in $[0, \mathrm{MOD})$
 */
inline int32_t modpow(uint_fast64_t x, uint64_t k, int32_t MOD) {
    assert (/* 0 <= x and */ x < static_cast<uint_fast64_t>(MOD));
    // start from 1 % MOD (not 1) so that k == 0 with MOD == 1 yields 0 and satisfies the postcondition
    uint_fast64_t y = 1 % MOD;
    for (; k; k >>= 1) {
        // x, y < MOD < 2^31, so each product fits in 64 bits
        if (k & 1) (y *= x) %= MOD;
        (x *= x) %= MOD;
    }
    assert (/* 0 <= y and */ y < static_cast<uint_fast64_t>(MOD));
    return y;
}
#line 5 "modulus/modsqrt.hpp"
/**
 * @brief find the sqrt $b$ of $a$ modulo $p$ (Tonelli-Shanks)
 * @param a any int; it is first reduced into $[0, p)$, so negative values are accepted
 * @param p must be a prime
 * @return some $b \in [0, p)$ with $b^2 \equiv a \pmod{p}$
 * @note -1 if not found (i.e. $a$ is a quadratic non-residue modulo $p$)
 */
int modsqrt(int a, int p) {
    // Euler's criterion: x^((p-1)/2) mod p is 1 for a non-zero quadratic residue, p - 1 otherwise
    auto legendre_symbol = [&](int x) {
        return modpow(x, (p - 1) / 2, p);
    };
    a %= p;
    if (a < 0) a += p;  // C++ % keeps the dividend's sign; modpow requires 0 <= x < p
    if (a == 0) return 0;
    if (p == 2) return a;
    assert (p >= 3);
    if (legendre_symbol(a) != +1) return -1;

    // find a quadratic non-residue b (one exists below p since p is an odd prime)
    int b = 1;
    while (legendre_symbol(b) == 1) {
        b += 1;
    }

    // write p - 1 = m * 2^e with m odd
    int e = 0;
    int m = p - 1;
    while (m % 2 == 0) {
        m /= 2;
        e += 1;
    }

    // invariant: x^2 == a * y (mod p); the loop drives y to 1, at which point x^2 == a
    int64_t x = modpow(a, (m - 1) / 2, p);  // a^((m-1)/2)
    int64_t y = a * x % p * x % p;          // a^m
    x = x * a % p;                          // a^((m+1)/2)
    int64_t z = modpow(b, m, p);            // b^m, has multiplicative order exactly 2^e
    while (y != 1) {
        // j = least exponent with y^(2^j) == 1
        int j = 0;
        for (int64_t t = y; t != 1; t = t * t % p) ++ j;
        if (e <= j) return -1;  // cannot happen for a residue; kept as a safety net
        z = modpow(z, 1ll << (e - j - 1), p);
        x = x * z % p;
        z = z * z % p;
        y = y * z % p;
        e = j;
    }
    assert (x * x % p == a);
    return static_cast<int>(x);
}