2021 ICPC 沈阳站 【L Perfect Matchings】 树上背包(树形dp)+容斥原理

2021 ICPC 沈阳站 【L Perfect Matchings】

树上背包(树形DP)+容斥原理

2021 ICPC 沈阳

img
img
题意:
  就是给你一个$2\times n$个点的完全图,从这个图里面抽出$2\times n - 1$条边,这些边形成一颗树,现在问你剩下的图里面点进行完美匹配有多少种方案?
  完美匹配方案可以理解为,对于一个$2\times n$个结点的图,找一个包含n条边的边集,由于每条边有两个端点,如果这个边集包含的点有$2 \times n$个,则是完全匹配(边集内任意两边没有公共端点)。
分析:
  先求不删边的情况下有多少种,之后减去边集里包含了被删除的边的个数。
  不删时,共有$C(2n, n) \times n! / 2^{n}$种($C(2n, n)$表示先选$n$条边的一个端点,$n!$表示剩下的$n$个点与之前选的$n$个点的匹配方式,除掉的是重复计算的,对于边$(x,y)$和$(y,x)$是相同的,而一共有$n$条边,可以理解为每条边交换还是不交换)。
  对于选择了$x$条来自被删除了的树上的边,剩下的$n - x$条边的选法有$C(2n-2x,n-x) \times (n - x)! / 2 ^ {n-x}$,从树上选$x$条满足条件的有多少种选法可以利用树上背包求解(树形dp),之后根据容斥原理减掉即可。
  树上背包:$dp[i][j][0/1]$表示以第$i$个点为根的子树,选择了$j$条符合条件的边,且$i$节点所在的边选不选的方案数。
AC代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include <bits/stdc++.h>

using namespace std;
typedef long long ll;

const int maxn = 4005;
const ll mod = 998244353;
int n, x, y;
vector<int> vt[maxn];
ll dp[maxn][maxn][5];
int sum[maxn];
ll tmp[maxn][2];
ll fac[maxn], inv[maxn];//fac[i]是i!,inv[i]是2^(-i)的逆元
ll num, cnt;

ll qim(ll a, ll b)
{
a %= mod;
ll res = 1;
while(b)
{
if(b & 1) res = a * res % mod;
a = a * a % mod;
b >>= 1;
}
return res % mod;
}

void init()
{
fac[0] = 1;
for(int i = 1; i < maxn; ++ i) fac[i] = 1ll * fac[i-1] * i % mod;
inv[0] = 1;
inv[1] = qim(2,mod-2);
for(int i = 2; i < maxn; ++ i) inv[i] = 1ll * inv[i-1] * inv[1] % mod;
}

ll C(int a, int b)
{
if (a < b) return 0;
return 1ll * fac[a] * qim(fac[b],mod-2) % mod * qim(fac[a-b],mod-2)%mod;
}

void dfs(int u, int fa)
{
dp[u][0][0] = 1;
sum[u] = 1;
int p;
for(int k = 0; k < vt[u].size(); ++k)
{
p = vt[u][k];
if(p == fa) continue;
dfs(p, u);
memset(tmp, 0, sizeof(tmp));//辅助数组
for(int i = 0; i <= sum[u] / 2; ++i)
{
for(int j = 0; j <= sum[p] / 2; ++j)//枚举从当前这个子树中取多少个
{
tmp[i + j][0] = (tmp[i + j][0] + dp[u][i][0] * (dp[p][j][0] + dp[p][j][1]) % mod) % mod;
tmp[i + j][1] = (tmp[i + j][1] + dp[u][i][1] * (dp[p][j][0] + dp[p][j][1]) % mod) % mod;
tmp[i + j + 1][1] = (tmp[i + j + 1][1] + dp[u][i][0] * dp[p][j][0] % mod) % mod;
}
}

for(int i = 0; i <= sum[u] / 2 + sum[p] / 2 + 1; ++i)
{
dp[u][i][0] = tmp[i][0];
dp[u][i][1] = tmp[i][1];
}

sum[u] += sum[p];
}
}

int main()
{
scanf("%d", &n);

for(int i = 1; i <= 2 * n - 1; ++i)
{
scanf("%d%d", &x, &y);
vt[x].push_back(y);
vt[y].push_back(x);
}

init();

dfs(1, 0);
//ll ans = ((C(2 * n, n) * fac[n]) % mod * inv[n]) % mod;//计算不删情况下的种类数,i从0开始循环的话就不需要计算了
ll ans = 0;

//i从0开始,i=0时,即从树中选了0条边,就是全部的取法
for(int i = 0; i <= n; ++i)
{
cnt = (dp[1][i][0] + dp[1][i][1]) % mod;
num = ((C(2 * n - 2 * i, n - i) * fac[n - i]) % mod * inv[n - i]) % mod;
num = (cnt * num) % mod;
if(i&1) ans = ((ans - num) % mod + mod) % mod;
else ans = (ans + num) % mod;
}

printf("%lld\n", ans % mod);

return 0;
}