HackerRank World Codesprint: Alien Flowers
Alien Flowers
問題
赤色と青色の玉を一列に並べる。 このとき、赤赤、赤青、青青、青赤、という並びの数がそれぞれ$A,B,C,D$となるような並べ方はいくつあるか。$10^9+7$で割った余りを答えよ。
解説
組合せの数${}_nC_r$を計算する問題に帰着する。
並びは$RRR\dots RBBB\dots BRRR\dots RBBB\dots B\dots$となる。 これは$RRR\dots R, BBB\dots B, RRR\dots R, BBB\dots B, \dots$というグループに分けることができる。このグループの数もそれぞれの色の玉の数も固定できるので、グループへの玉の分配の方法を計算すればよい。
使う玉の個数は$1+A+B+C+D$である。 列が赤色で始まるとき、赤色の個数は$1+A+D$で、青色の個数は$C+D$である。青で始まる場合も同様。 $B,D$は色の切り替わりの回数なので$|B - D| \le 1$である。そうでない場合は個数は$0$となる。 $B \lt D \; (= B+1)$であるとき、赤青という並びより青赤という並びの方が多いので、赤色で始まって赤色で終わる。 赤で始まるとき赤色のグループの数は$1 + D$であり、青色のグループの数は$B$である。 $n$個のグループに$r$個の玉を分配するとき、それぞれのグループはひとつ以上の玉を含むので、そのような組合せの数は${}_{(n-r)+r-1}C_{r-1}$となる。
実装
組合せは$O(n^2)$で計算すると間に合わない。
#include <iostream>
#include <vector>
#include <cassert>
#define repeat(i,n) for (int i = 0; (i) < (n); ++(i))
#define repeat_from(i,m,n) for (int i = (m); (i) < (n); ++(i))
typedef long long ll;
using namespace std;
const int mod = 1e9+7;
ll inv(ll x) { // O(logn)
assert (0 < x and x < mod);
ll y = 1;
for (int i = 0; (1 << i) <= mod - 2; ++ i) {
if ((mod - 2) & (1 << i)) {
y = y * x % mod;
}
x = x * x % mod;
}
return y;
}
ll combination(ll n, ll r) { // O(nlogn), O(1)
assert (0 <= n and 0 <= r or r <= n);
static vector<ll> fact(1,1);
static vector<ll> ifact(1,1);
if (not (n < fact.size())) {
int l = fact.size();
fact.resize(n + 1);
ifact.resize(n + 1);
repeat_from (i,l,n+1) {
fact[i] = fact[i-1] * i % mod;
ifact[i] = inv(fact[i]);
}
}
r = min(r, n - r);
return fact[n] * ifact[n-r] % mod * ifact[r] % mod;
}
ll distribute_strict(int n, int r) { // distribute n same things into r distinguishable groups, each group has positive number of things
assert (n >= r);
return combination((n-r)+r-1,r-1);
}
ll starts_with_r(int rr, int rb, int bb, int br) {
if (not (rb == br or rb == br+1)) return 0;
if (rb == 0 and br == 0) return bb == 0 ? 1 : 0;
ll r = distribute_strict(1 + rr + br, 1 + br);
ll b = distribute_strict( bb + rb, rb);
return r * b % mod;
}
int main() {
int a, b, c, d; cin >> a >> b >> c >> d;
cout << (starts_with_r(a,b,c,d) + starts_with_r(c,d,a,b)) % mod << endl;
}