解けず。$1000$点だからとF openしたので座ってるだけになった。

solution

回文$a$の周期性$d$で分類する。$O(d(N)^2)$。

回文$a$はちょうど$K^{\lceil \frac{N}{2} \rceil}$個存在する。 しかしこれを単に$N$倍しても答えにはならない。 何度か回転$f : a_0a_1a_2\dots a_{n-1} \mapsto a_1a_2\dots a_{n-1}a_0$すると他の回文と一致する場合がある。

ここで回文$a$をその周期で分類する。 回文の周期$d$とは、$f^d(a) = a$となるような最小の$d \ge 1$のこと。 ある周期$d$の回文の個数を(回転で一致する場合は同一視して)$\mathrm{num}(d)$とすると、 $\mathrm{ans} = \sum_{d | N} \mathrm{num(d)} \cdot d$となる。

$\mathrm{num}(d)$を求めよう。 単に$K^{\lceil \frac{d}{2} \rceil}$個だと$d’|d$な$d’ \lt d$を重複して数えること、$d$が偶数なら$\frac{d}{2}$回の回転で他の周期$d$の回文と衝突することから、$\mathrm{num}(d) = (K^{\lceil \frac{d}{2} \rceil} - \sum_{d’|d \land d’\lt d} \mathrm{num}(d’))\cdot \frac{1}{2 - d \bmod 2}$となる。

implementation

#include <iostream>
#include <vector>
#include <set>
#include <map>
#include <cmath>
#include <cassert>
#define repeat(i,n) for (int i = 0; (i) < (n); ++(i))
typedef long long ll;
using namespace std;

vector<int> sieve_of_eratosthenes(int n) { // enumerate primes in [2,n] with O(n log log n)
    vector<bool> is_prime(n+1, true);
    is_prime[0] = is_prime[1] = false;
    for (int i = 2; i*i <= n; ++i)
        if (is_prime[i])
            for (int k = i+i; k <= n; k += i)
                is_prime[k] = false;
    vector<int> primes;
    for (int i = 2; i <= n; ++i)
        if (is_prime[i])
            primes.push_back(i);
    return primes;
}
vector<ll> list_prime_factrors(ll n, vector<int> const & primes) {
    vector<ll> result;
    for (int p : primes) {
        if (n < p *(ll) p) break;
        while (n % p == 0) {
            result.push_back(p);
            n /= p;
        }
    }
    if (n != 1) result.push_back(n);
    return result;
}
ll powi(ll x, ll y, ll p) { // O(log y)
    assert (y >= 0);
    x = (x % p + p) % p;
    ll z = 1;
    for (ll i = 1; i <= y; i <<= 1) {
        if (y & i) z = z * x % p;
        x = x * x % p;
    }
    return z;
}

const ll mod = 1e9+7;
int main() {
    int n, k; cin >> n >> k;
    set<int> ds { 1 };
    for (ll p : list_prime_factrors(n, sieve_of_eratosthenes(sqrt(n) + 3))) {
        set<int> prev_ds = ds;
        for (int d : prev_ds) {
            ds.insert(d * p);
        }
    }
    ll ans = 0;
    map<int,ll> num;
    for (int d : ds) {
        ll acc = powi(k, (d+1)/2, mod);
        for (int d2 : ds) if (d % d2 == 0 and d2 < d) {
            acc -= num[d2];
        }
        num[d] = (acc % mod + mod) % mod;
        ans += num[d] * d / (d % 2 == 0 ? 2 : 1);
    }
    ans %= mod;
    cout << ans << endl;
    return 0;
}