AOJ 2255: 6/2(1+2)
solution
字句解析と括弧までの構文解析をして、その上で区間DP。各 `<expr>` をそのあり得る値の集合へ写す。入力の長さを $N$、演算子の数を $K$ として $O(N + 2^K)$。
区間DPの部分は以下でもできるようだが、実装量も速度も悪そうなので勧めない。
- 再帰 どこで分割するかを毎回全部試す (これをメモ化すれば先に言った区間DP)
- 演算子の優先順位を全列挙 (`next_permutation` とかで頑張る)
implementation
#include <cassert>
#include <cctype>
#include <iostream>
#include <set>
#include <vector>
// repeat(i, n): loop header that declares `int i` and runs it over 0, 1, ..., int(n) - 1.
// The bound is cast to int and re-evaluated on every iteration.
#define repeat(i, n) for (int i = 0; (i) < int(n); ++(i))
// repeat_from(i, m, n): same as repeat, but `i` starts at m, i.e. it runs over [m, int(n)).
#define repeat_from(i, m, n) for (int i = (m); (i) < int(n); ++(i))
using namespace std;
// vectors(d0, d1, ..., dk, init) builds a nested vector with extents d0 x d1 x ... x dk
// in which every innermost element is a copy of `init`.

// Base case: a flat vector holding `x` copies of `a`.
template <typename X, typename T>
auto vectors(X x, T a) {
    return vector<T>(x, a);
}

// Recursive case: build the (k-1)-dimensional container once, then replicate it `x` times.
template <typename X, typename Y, typename Z, typename... Zs>
auto vectors(X x, Y y, Z z, Zs... zs) {
    auto inner = vectors(y, z, zs...);
    using inner_type = decltype(inner);
    return vector<inner_type>(x, inner);
}
// Node of the parsed expression tree. It is used in two shapes:
// - number literal: `immidiate` holds the (non-negative) value; parse_value leaves
//   `value` and `op` empty.
// - operator chain "t0 op0 t1 op1 ... tk": `immidiate` is -1, `value` holds the terms
//   t0..tk and `op` the operator characters, with op[i] sitting between value[i] and
//   value[i + 1] (built by parse_expr, consumed by solve).
struct expr_t {
// Literal value, or the sentinel -1 when this node is an operator chain.
int immidiate;
// Operand terms of the chain, in source order.
vector<expr_t> value;
// Operators ('+', '-', '*', '/') between consecutive terms.
vector<char> op;
};
// Parses an unsigned decimal literal.
//
// Consumes the maximal run of digits starting at `first` (which must point at a digit)
// and advances `first` past it. Returns a leaf node whose `immidiate` is the parsed
// value and whose `value` / `op` are empty.
// NOTE(review): the accumulator is a plain int with no overflow check; this relies on
// the input numbers being small enough -- confirm against the problem constraints.
expr_t parse_value(string::const_iterator & first, string::const_iterator last) {
    // The <cctype> classifiers require an argument representable as unsigned char
    // (or EOF); passing a negative plain char is undefined behaviour, so convert first.
    auto is_digit = [](char c) { return isdigit(static_cast<unsigned char>(c)) != 0; };
    assert (first != last);
    assert (is_digit(*first));
    int acc = 0;
    while (first != last and is_digit(*first)) {
        acc = acc * 10 + (*first - '0');
        ++ first;
    }
    // Standard aggregate initialisation instead of the C compound literal `(expr_t) { acc }`,
    // which is only a compiler extension in C++.
    return expr_t{ acc, {}, {} };
}
// Forward declaration: parse_term and parse_expr are mutually recursive.
expr_t parse_expr(string::const_iterator & first, string::const_iterator last);

// Parses one term: either a parenthesised sub-expression "( <expr> )" or a number literal.
// `first` must not be at `last`; on return it points just past the term.
expr_t parse_term(string::const_iterator & first, string::const_iterator last) {
    assert (first != last);
    if (*first != '(') {
        return parse_value(first, last);
    }
    ++ first;  // consume '('
    expr_t expr = parse_expr(first, last);
    // parse_expr also returns when it runs out of input, so make sure there is a
    // character to look at before dereferencing (an unclosed '(' ends up here).
    assert (first != last and *first == ')');
    ++ first;  // consume ')'
    return expr;
}
// Parses a flat chain "term (op term)*" where op is one of + - * /.
// Stops at end of input or at a ')' (which is left unconsumed for the caller).
// The result is a chain node: immidiate == -1, terms in `value`, operators in `op`,
// without applying any precedence -- even a single term is wrapped in such a node.
expr_t parse_expr(string::const_iterator & first, string::const_iterator last) {
    assert (first != last);
    expr_t chain;
    chain.immidiate = -1;  // mark as "not a literal"
    chain.value.push_back(parse_term(first, last));
    for (;;) {
        if (first == last) break;
        const char c = *first;
        if (c == ')') break;
        assert (c == '+' or c == '-' or c == '*' or c == '/');
        chain.op.push_back(c);
        ++ first;
        chain.value.push_back(parse_term(first, last));
    }
    return chain;
}
// Parses the whole string `s` as one expression and returns its tree.
// Asserts that the parser consumed every character of the input.
expr_t parse(string const & s) {
    string::const_iterator cursor = s.begin();
    const string::const_iterator end = s.end();
    expr_t root = parse_expr(cursor, end);
    assert (cursor == end);
    return root;
}
// Returns every value `expr` can evaluate to when the operators of each chain may be
// applied in any order (i.e. over all ways of fully parenthesising the chain).
// Literals yield a one-element set. Divisions by zero are skipped rather than evaluated;
// '/' is C++ integer division.
set<int> solve(expr_t const & expr) {
    // Leaf: a number literal has exactly one value.
    if (expr.immidiate != -1) {
        return set<int>({ expr.immidiate });
    }

    const int n = expr.value.size();
    // dp[l][r] = all values obtainable from the terms value[l], ..., value[r - 1].
    vector<vector<set<int> > > dp(n, vector<set<int> >(n + 1));
    for (int i = 0; i < n; ++ i) {
        dp[i][i + 1] = solve(expr.value[i]);
    }

    // Interval DP over increasing interval length.
    for (int len = 2; len <= n; ++ len) {
        for (int l = 0; l + len <= n; ++ l) {
            const int r = l + len;
            set<int> & acc = dp[l][r];
            // Choose the operator applied last: op[m - 1] joins [l, m) and [m, r).
            for (int m = l + 1; m < r; ++ m) {
                set<int> const & lhs = dp[l][m];
                set<int> const & rhs = dp[m][r];
                switch (expr.op[m - 1]) {
                    case '+':
                        for (int a : lhs) for (int b : rhs) acc.insert(a + b);
                        break;
                    case '-':
                        for (int a : lhs) for (int b : rhs) acc.insert(a - b);
                        break;
                    case '*':
                        for (int a : lhs) for (int b : rhs) acc.insert(a * b);
                        break;
                    case '/':
                        for (int b : rhs) {
                            if (b == 0) continue;  // this split is invalid for divisor 0
                            for (int a : lhs) acc.insert(a / b);
                        }
                        break;
                    default:
                        break;  // unknown operators contribute nothing
                }
            }
        }
    }
    return dp[0][n];
}
// Reads whitespace-separated expressions from standard input, one token each, until a
// lone "#" and prints for each the number of distinct values it can evaluate to
// (the size of the set returned by solve), one per line.
int main() {
    string s;
    // Also stop when the read itself fails (EOF or stream error): otherwise a missing
    // "#" terminator would hand an empty string to parse() on every iteration.
    while (cin >> s and s != "#") {
        expr_t expr = parse(s);
        int result = solve(expr).size();
        // <iostream> is already included; printf was used without <cstdio>.
        cout << result << '\n';
    }
    return 0;
}