解法

概要

上手な不変条件を見付けるとそれが必要十分になってるやつ。 $O(|S|)$。

詳細

次を確認する:

  • 操作後の文字列には常に同じ文字が連続する部分がある
  • $\phi(a) = 0, \phi(b) = 1, \phi(c) = 2$ とおいて $\sum \phi(S_i) \equiv m \pmod{3}$ を考えたときこの値 $m \equiv \phi(S)$ は操作によって不変

これは構築できる文字列の必要条件になっている。

いくつか例外を除いて、これは十分条件にもなっている。 つまり、次の条件を満たす文字列 $T$ の個数を数えればよい。 これは $|S|, m$ だけに依存する単純なDP。

  • 同じ文字が連続する部分がある
  • 入力 $S$ と同じ $m$ の値を持つ

加えて $S$ に同じ文字が連続する部分がない場合は $T = S$ の分の $1$ を足す。

ただし次は例外なので別に処理する。

  • $S$ がすべて同じ文字から成る場合
  • $|S| = 2, 3$ の場合

メモ

実装

#include <algorithm>
#include <array>
#include <functional>
#include <iostream>
#include <numeric>
#include <string>
#include <unordered_set>
#include <vector>
#define REP(i, n) for (int i = 0; (i) < (int)(n); ++ (i))
#define ALL(x) begin(x), end(x)
using ll = long long;
using namespace std;
template <typename X, typename T> auto vectors(X x, T a) { return vector<T>(x, a); }
template <typename X, typename Y, typename Z, typename... Zs> auto vectors(X x, Y y, Z z, Zs... zs) { auto cont = vectors(y, z, zs...); return vector<decltype(cont)>(x, cont); }

template <int32_t MOD>
struct mint {
    int64_t value;
    mint() = default;
    mint(int64_t value_) : value(value_) {}
    inline mint<MOD> operator + (mint<MOD> other) const { int64_t c = this->value + other.value; return mint<MOD>(c >= MOD ? c - MOD : c); }
    inline mint<MOD> & operator += (mint<MOD> other) { this->value += other.value; if (this->value >= MOD) this->value -= MOD; return *this; }
};


constexpr int MOD = 998244353;

mint<MOD> solve1(string s) {
    unordered_set<string> used;
    function<void (void)> go = [&]() {
        if (used.count(s)) return;
        used.insert(s);
        REP (i, s.length() - 1) {
            if (s[i] != s[i + 1]) {
                char x = s[i];
                char y = s[i + 1];
                s[i] = s[i + 1] = 'a' ^ 'b' ^ 'c' ^ x ^ y;
                go();
                s[i] = x;
                s[i + 1] = y;
            }
        }
    };
    go();
    return used.size();
}

mint<MOD> solve2(int n, int m, bool is_all_different) {
    auto dp = vectors(n, array<array<mint<MOD>, 4>, 3>());
    dp[0][0][0] = 1;  // "a"
    dp[0][1][1] = 1;  // "b"
    dp[0][2][2] = 1;  // "c"
    REP (i, n - 1) {
        REP (j, 3) {
            REP (last_c, 4) {
                REP (c, 3) {
                    int next_c = (last_c == 3 or (last_c == c) ? 3 : c);
                    dp[i + 1][(j + c) % 3][next_c] += dp[i][j][last_c];  // add c
                }
            }
        }
    }
    return dp[n - 1][m][3] + is_all_different;
}

mint<MOD> solve(string const & s) {
    int n = s.length();
    if (n < 4) return solve1(s);

    bool is_all_same = (count(ALL(s), s.front()) == n);
    if (is_all_same) return 1;

    bool is_all_different = true;
    REP (i, n - 1) {
        if (s[i] == s[i + 1]) {
            is_all_different = false;
        }
    }

    int m = (accumulate(ALL(s), 0) - n * 'a') % 3;
    return solve2(n, m, is_all_different);
}

int main() {
    string s; cin >> s;
    cout << solve(s).value << endl;
    return 0;
}