「topological sortの個数数えればいいのか綺麗だなあ」までは辿り着いたが「でも数えるのは無理なんだよな」で刈ってしまった。 一般には実際にNPよりも難しいらしい12

solution

topological sortの数を数えるDP。$O(N(N + K))$。

次の規則での並べ方を数えればよい。

  • $0$を$N$個、$1, \dots, N$をそれぞれ$K - 1$個
  • $i \in \{ 1, \dots, N \}$ は $i$番目(1-based)の$0$ より後ろに出現する
  • $i \in \{ 2, \dots, N \}$ は $1$番目の$i - 1$ より後ろに出現する

「$A$が$B$より後ろに出現する」の形の制約は、制約をDAGとして書いたグラフのtopological sortとして表現できる。 なのでこのDAGのtopological sortの数を数えればよい。 愚直に見れば$O(K^N)$かかる。 しかし$2$番目以降の$i$の出現は同じ$i$以外に対して独立なので、残りの頂点の数を$R$として${}_{R}C_{K-2}$を掛けることでまとめて処理できる。 これで全体で$O(2^N)$に落ちる。

まとめ

  • 「$A$が$B$より後ろに出現する」はtopological sort
  • topological sortの数を数えるのは難しいが、グラフを非連結にできれば容易

implementation

#include <cassert>
#include <iostream>
#include <vector>
#define REP(i, n) for (int i = 0; (i) < int(n); ++ (i))
#define REP3(i, m, n) for (int i = (m); (i) < int(n); ++ (i))
using ll = long long;
using namespace std;
template <typename X, typename T> auto vectors(X x, T a) { return vector<T>(x, a); }
template <typename X, typename Y, typename Z, typename... Zs> auto vectors(X x, Y y, Z z, Zs... zs) { auto cont = vectors(y, z, zs...); return vector<decltype(cont)>(x, cont); }

ll powmod(ll x, ll y, ll m) {
    assert (0 <= x and x < m);
    assert (0 <= y);
    ll z = 1;
    for (ll i = 1; i <= y; i <<= 1) {
        if (y & i) z = z * x % m;
        x = x * x % m;
    }
    return z;
}
ll modinv(ll x, ll p) {
    assert (x % p != 0);
    return powmod(x, p - 2, p);
}
template <int32_t MOD>
int32_t fact(int n) {
    static vector<int32_t> memo(1, 1);
    while (n >= memo.size()) {
        memo.push_back(memo.back() *(int64_t) memo.size() % MOD);
    }
    return memo[n];
}
template <int32_t PRIME>
int32_t inv_fact(int n) {
    static vector<int32_t> memo(1, 1);
    while (n >= memo.size()) {
        memo.push_back(memo.back() *(int64_t) modinv(memo.size(), PRIME) % PRIME);
    }
    return memo[n];
}
template <int MOD>
int choose(int n, int r) {
    assert (0 <= r and r <= n);
    return fact<MOD>(n) *(ll) inv_fact<MOD>(n - r) % MOD *(ll) inv_fact<MOD>(r) % MOD;
}

constexpr int mod = 1e9 + 7;
int solve(int n, int k) {
    if (k == 1) return 1;
    auto dp = vectors(n + 1, n + 1, int());
    dp[0][0] = 1;
    REP3 (i, 1, n + 1) {
        REP (j, i + 1) {
            if (i - 1 >= 0) {
                dp[i][j] += dp[i - 1][j];
            }
            if (j - 1 >= 0) {
                int remaining = (n - i) + (k - 1) * (n - (j - 1));
                dp[i][j] += dp[i][j - 1] *(ll) choose<mod>(remaining - 1, k - 2) % mod;
            }
            if (dp[i][j] >= mod) {
                dp[i][j] -= mod;
            }
        }
    }
    return dp[n][n] *(ll) fact<mod>(n) % mod;
}

int main() {
    int n, k; cin >> n >> k;
    cout << solve(n, k) << endl;
    return 0;
}