// Algorithm contest: 2021 ICPC Southeastern Europe Regional Contest - Werewolves
//#include "stdafx.h"
#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <cstdlib>
#include <vector>
using namespace std;
// 64-bit integer used for intermediate products of two residues.
typedef long long ll;
// Capacity of every per-vertex array (3000 plus a small safety margin).
const int MAXN = 3000 + 10;
// Modulus applied to every DP accumulation and to the final answer.
const int mod = 998244353;
// Adjacency lists; main() stores each input edge in both directions.
vector<int> g[MAXN];
// n: number of vertices read by main().
// m: number of vertices carrying the colour currently processed by solve();
//    dfs() also uses it as the largest difference it keeps track of.
int n, m;
// c[i] is the colour of vertex i. The initializer is sample data; main()
// overwrites entries 1..n with the values it reads.
int c[MAXN]={0,1,1,3,3};
//int c[MAXN]={0,2,3,3};
// vis[x] is non-zero once solve() has already handled colour value x.
char vis[MAXN];
// DP tables filled by dfs() for the colour being processed. For a vertex u and
// a difference d (1 <= d <= m), counting connected vertex sets that contain u
// and otherwise only vertices below u in the traversal:
//   dp1[u][d] - sets with d more vertices of that colour than other vertices,
//   dp2[u][d] - sets with d fewer vertices of that colour than other vertices,
//   dp3[u]    - sets with equally many of both.
// All values are stored modulo `mod`.
int dp1[MAXN][MAXN], dp2[MAXN][MAXN], dp3[MAXN];
// Scratch copies of dp1[u][*], dp2[u][*] and dp3[u] taken by dfs() just before
// a child's table is merged into u's table.
ll tmp1[MAXN], tmp2[MAXN], tmp3;
// Running answer (mod `mod`): dfs() adds every dp1 entry to it, solve() prints it.
int res=0;
// Sample edge list (pairs of endpoints); only referenced by the commented-out
// debug code inside solve().
int edge[610]={1, 2, 1, 3,1,4};
//int edge[610]={1, 2, 2, 3};
// Tree-knapsack pass for the colour of vertex `ci` (the "target" colour).
//
// Processes the part of the tree reachable from `u` without going back through
// `start`, and fills, modulo `mod` and for differences 1 <= d <= m only:
//   dp1[u][d] - connected sets containing u (and otherwise only vertices below
//               u) with d more target-coloured vertices than other vertices,
//   dp2[u][d] - the same with d fewer target-coloured vertices,
//   dp3[u]    - the same with equally many of both.
// Differences larger than m are not stored. Every dp1[u][*] entry is added to
// the global `res` before returning.
//
// Parameters:
//   u     - vertex being processed
//   start - neighbour we arrived from (skipped when iterating g[u])
//   ci    - vertex whose colour c[ci] is the target colour
// Returns the number of vertices in u's subtree.
int dfs(int u, int start, int ci) {
    // cell <- (cell + delta) mod `mod`. delta is at most the sum of two
    // products of residues, which fits comfortably in 64 bits.
    const auto accumulate = [](int& cell, long long delta) {
        cell = (cell + delta) % mod;
    };

    int* const upPlus = dp1[u];
    int* const upMinus = dp2[u];
    int subtree = 1;

    // The single-vertex set {u}: difference +1 if u has the target colour,
    // -1 otherwise.
    if (c[u] != c[ci]) {
        upMinus[1] = 1;
    } else {
        upPlus[1] = 1;
    }

    for (const int child : g[u]) {
        if (child == start) {
            continue;
        }

        // Recurse first: the call reuses the global scratch buffers.
        const int childSize = dfs(child, u, ci);

        const int ownLimit = std::min(subtree, m);
        const int childLimit = std::min(childSize, m);
        const int* const downPlus = dp1[child];
        const int* const downMinus = dp2[child];
        const int downTie = dp3[child];

        // Freeze u's current table so the in-place updates below always
        // combine "u before this child" with the child's table.
        tmp3 = dp3[u];
        for (int d = 1; d <= ownLimit; ++d) {
            tmp1[d] = upPlus[d];
            tmp2[d] = upMinus[d];
        }

        // u-side difference 0 combined with every child-side difference.
        accumulate(dp3[u], tmp3 * downTie);
        for (int e = 1; e <= childLimit; ++e) {
            accumulate(upPlus[e], tmp3 * downPlus[e]);
            accumulate(upMinus[e], tmp3 * downMinus[e]);
        }

        for (int d = 1; d <= ownLimit; ++d) {
            const long long ownPlus = tmp1[d];
            const long long ownMinus = tmp2[d];

            // u-side difference +-d combined with child-side difference 0.
            accumulate(upPlus[d], ownPlus * downTie);
            accumulate(upMinus[d], ownMinus * downTie);

            for (int e = 1; e <= childLimit; ++e) {
                // Same sign on both sides: magnitudes add (kept only up to m).
                if (d + e <= m) {
                    accumulate(upPlus[d + e], ownPlus * downPlus[e]);
                    accumulate(upMinus[d + e], ownMinus * downMinus[e]);
                }

                // Opposite signs: the larger magnitude decides the sign.
                if (d > e) {
                    accumulate(upPlus[d - e], ownPlus * downMinus[e]);
                    accumulate(upMinus[d - e], ownMinus * downPlus[e]);
                } else if (d < e) {
                    accumulate(upPlus[e - d], ownMinus * downPlus[e]);
                    accumulate(upMinus[e - d], ownPlus * downMinus[e]);
                } else {
                    accumulate(dp3[u], ownPlus * downMinus[e] + ownMinus * downPlus[e]);
                }
            }
        }

        subtree += childSize;
    }

    // Every set counted in dp1[u][*] has a positive difference; add them all
    // to the running answer.
    const int reported = std::min(subtree, m);
    for (int d = 1; d <= reported; ++d) {
        res = (res + upPlus[d]) % mod;
    }
    return subtree;
}
// Runs dfs() once per distinct colour and prints the accumulated answer.
//
// For each colour not yet marked in `vis`, the first vertex carrying it acts
// as its representative: `m` is set to the number of vertices with that
// colour, the DP tables are cleared for differences 0..m on vertices 1..n, and
// dfs() is started from vertex 1. Finally `res` is printed followed by a
// newline.
void solve() {
    for (int rep = 1; rep <= n; ++rep) {
        const int colour = c[rep];
        if (vis[colour]) {
            continue;
        }
        vis[colour] = 1;

        // Number of vertices sharing this colour.
        m = static_cast<int>(std::count(c + 1, c + n + 1, colour));

        // Only indices 0..m are ever touched by dfs() for this colour.
        for (int node = 1; node <= n; ++node) {
            std::fill(dp1[node], dp1[node] + m + 1, 0);
            std::fill(dp2[node], dp2[node] + m + 1, 0);
            dp3[node] = 0;
        }

        dfs(1, 0, rep);
    }
    printf("%d\n", res);
}
int main() {
//*
scanf("%d",&n);
for(int i = 1; i <= n; i++) scanf("%d", &c[i]);
int u, v;
for(int i = 1; i < n; i++) {
scanf("%d", &u); scanf("%d", &v);
g[u].push_back(v);
g[v].push_back(u);
}//*/
solve();
return 0;
}