1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
| #include <bits/stdc++.h>
using namespace std;

// Debug helper: prints `name: value` followed by a newline on cout.
// The trailing semicolon is part of the macro text.
#define gg(x) cout << #x << ": " << x << "\n";
// Short aliases for the 64-bit integer types.
#define LL long long
#define ULL unsigned long long
#define Pair pair<int ,int >
// Left/right child index expressions; they expect a variable named `rt` in scope.
#define ls rt<<1
#define rs rt<<1|1
#define PI acos(-1.0)
#define eps 1e-8
// Shorthand for the members of a pair.
#define fi first
#define se second
#define ll long long

// mod and MAXN are not referenced by the code visible in this file.
const int mod = 998244353;
const int MAXN = 2e9;
// Capacity of every per-node and per-value array below.
const int MS = 100009;

// n: number of tree nodes, m: number of queries (both read in solve()).
// k is not referenced by the code visible in this file.
int n;
int m;
int k;
// a[i]: value attached to node i (nodes are 1-based); also used as an index into cnt[].
int a[MS];
// vc[u]: neighbours of node u (undirected adjacency list).
vector<int > vc[MS];
// sz[u]: subtree size of u; zson[u]: child of u with the largest subtree (0 if none).
int sz[MS];
int zson[MS];
// Node that cal() refuses to descend into; 0 means "skip nothing".
int S;
// cnt[x]: how many currently counted nodes carry value x; tot: number of x with cnt[x] > 0.
int cnt[MS];
int tot;
// ac[u]: number of distinct values inside the subtree of u, filled in by dsu().
int ac[MS];
void dfs(int u, int f){ sz[u] = 1; for(auto v:vc[u]){ if(v != f){ dfs(v, u); sz[u] += sz[v]; if(sz[v] > sz[zson[u]]) zson[u] = v; } } }
void cal(int u, int f, int val){ cnt[a[u]] += val; if(val == 1 && cnt[a[u]] == 1) ++tot; else if(val == -1 && cnt[a[u]] == 0) --tot; for(auto v:vc[u]){ if(v != f && v != S){ cal(v, u, val); } } }
void dsu(int u, int f, int w){ for(auto v:vc[u]){ if(v != f && v != zson[u]){ dsu(v, u, 0); } } if(zson[u]) dsu(zson[u], u, 1), S = zson[u]; cal(u, f, 1), S = 0; ac[u] = tot; if(!w) cal(u, f, -1); }
void solve(){ cin >> n; for(int i=2;i<=n;i++){ int u,v; cin >> u >> v; vc[u].push_back(v); vc[v].push_back(u); } for(int i=1;i<=n;i++) cin >> a[i]; dfs(1, 0); dsu(1, 0, 1);
cin >> m; while(m--){ int x; cin >> x; cout << ac[x] << "\n"; } }
int main() { ios::sync_with_stdio(false); int ce = 1;
while(ce--) solve();
return 0; }
|