1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
| #include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Upper bound on vertex ids: vertices are 1..2n, so 2n must be < maxn.
const int maxn = 4005;
// Size of the "number of chosen edges" axis of dp. A subtree with s vertices
// holds at most s/2 disjoint edges, and the merge in dfs writes indices up to
// sum[u]/2 + sum[p]/2 + 1 <= n + 1 <= maxn/2 + 1 (because 2n < maxn).
const int maxm = maxn / 2 + 2;
const ll mod = 998244353;

int n, x, y;             // n: half the vertex count; x, y: edge endpoints as read
vector<int> vt[maxn];    // adjacency lists of the tree
// dp[u][i][s]: ways (mod `mod`) to choose i pairwise-disjoint tree edges inside
// the subtree of u; s == 1 iff u is an endpoint of one of the chosen edges.
// Only s in {0, 1} is used, so the last dimension is 2 (it was 5, ~641 MB total).
ll dp[maxn][maxm][2];
int sum[maxn];           // subtree sizes, filled by dfs
ll tmp[maxn][2];         // scratch buffer for merging a child into its parent
ll fac[maxn], inv[maxn]; // fac[i] = i!, inv[i] = 2^{-i} (mod `mod`), filled by init
ll num, cnt;             // scratch terms for the final inclusion-exclusion sum
// Modular exponentiation by repeated squaring: returns a^b modulo `mod`.
// b must be non-negative (the loop runs until b reaches zero).
ll qim(ll a, ll b) {
    ll base = a % mod;
    ll acc = 1;
    for (; b != 0; b >>= 1) {
        if (b & 1) acc = base * acc % mod;
        base = base * base % mod;
    }
    return acc % mod;
}
void init() { fac[0] = 1; for(int i = 1; i < maxn; ++ i) fac[i] = 1ll * fac[i-1] * i % mod; inv[0] = 1; inv[1] = qim(2,mod-2); for(int i = 2; i < maxn; ++ i) inv[i] = 1ll * inv[i-1] * inv[1] % mod; }
// Binomial coefficient C(a, b) modulo `mod`, built from the fac[] table
// (init() must have run, and a must be < maxn).
// Returns 0 when b lies outside [0, a]; previously a negative b indexed fac[] out of bounds.
ll C(int a, int b) {
    if (b < 0 || a < b) return 0;
    // One modular inversion of b!(a-b)! instead of inverting each factor separately.
    const ll denom = fac[b] * fac[a - b] % mod;
    return fac[a] * qim(denom, mod - 2) % mod;
}
void dfs(int u, int fa) { dp[u][0][0] = 1; sum[u] = 1; int p; for(int k = 0; k < vt[u].size(); ++k) { p = vt[u][k]; if(p == fa) continue; dfs(p, u); memset(tmp, 0, sizeof(tmp)); for(int i = 0; i <= sum[u] / 2; ++i) { for(int j = 0; j <= sum[p] / 2; ++j) { tmp[i + j][0] = (tmp[i + j][0] + dp[u][i][0] * (dp[p][j][0] + dp[p][j][1]) % mod) % mod; tmp[i + j][1] = (tmp[i + j][1] + dp[u][i][1] * (dp[p][j][0] + dp[p][j][1]) % mod) % mod; tmp[i + j + 1][1] = (tmp[i + j + 1][1] + dp[u][i][0] * dp[p][j][0] % mod) % mod; } }
for(int i = 0; i <= sum[u] / 2 + sum[p] / 2 + 1; ++i) { dp[u][i][0] = tmp[i][0]; dp[u][i][1] = tmp[i][1]; }
sum[u] += sum[p]; } }
int main() { scanf("%d", &n);
for(int i = 1; i <= 2 * n - 1; ++i) { scanf("%d%d", &x, &y); vt[x].push_back(y); vt[y].push_back(x); }
init();
dfs(1, 0); ll ans = 0;
for(int i = 0; i <= n; ++i) { cnt = (dp[1][i][0] + dp[1][i][1]) % mod; num = ((C(2 * n - 2 * i, n - i) * fac[n - i]) % mod * inv[n - i]) % mod; num = (cnt * num) % mod; if(i&1) ans = ((ans - num) % mod + mod) % mod; else ans = (ans + num) % mod; }
printf("%lld\n", ans % mod);
return 0; }
|