1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
| #include <bits/stdc++.h>
using namespace std;

// Helper macros. Each preprocessor directive must occupy its own line:
// a '#define' that follows other tokens on the same line is not a directive.
#define gg(x) cout << #x << ": " << x << "\n";
#define LL long long
#define ULL unsigned long long
#define Pair pair<int ,int >
#define ls rt<<1
#define rs rt<<1|1
#define PI acos(-1.0)
#define eps 1e-8
#define fi first
#define se second
#define ll long long

const int mod = 998244353;
const int MAXN = 2e9;
// Array capacity for per-node / per-value tables (must exceed the largest n).
const int MS = 100009;
// ---- Global problem state (static storage, so everything starts at zero) ----
int n, m, k;        // n: node count, k: distance limit (read in solve()); m is not read in the visible code
vector<int> vc[MS]; // vc[u]: children of node u
int w[MS];          // w[u]: value attached to node u (read in solve())
int sz[MS];         // sz[u]: size of u's subtree (filled by dfs)
int zson[MS];       // zson[u]: child of u with the largest subtree, 0 if none
int dep[MS];        // dep[u]: depth of u (dep[0] stays 0, so the root gets depth 1)

// Node of a dynamically allocated segment tree over depths.
// Index 0 means "no node"; l/r are child indices into the pool p.
struct node {
    int l;
    int r;
    int sum;
};
node p[MS << 5];    // node pool; slots handed out via ++tot
int root[MS];       // root[x]: segment tree root for nodes whose value is x
int tot;            // number of pool slots allocated so far
LL ans;             // accumulated pair count
void modify(int pos, int l, int r, int &rt, int val){ if(!rt) rt = ++tot; p[rt].sum += val; if(l == r) return; int m = l+r>>1; if(m >= pos) modify(pos,l,m,p[rt].l,val); else modify(pos,m+1,r,p[rt].r,val); }
// Range sum over positions [L, R] in the segment tree covering [l, r] rooted
// at slot `rt`. Empty subtrees (rt == 0) contribute 0; no nodes are created.
int query(int L, int R, int l, int r, int &rt) {
    if (rt == 0) return 0;
    if (L <= l && r <= R) return p[rt].sum;
    const int mid = (l + r) >> 1;
    const int fromLeft = (L <= mid) ? query(L, R, l, mid, p[rt].l) : 0;
    const int fromRight = (R > mid) ? query(L, R, mid + 1, r, p[rt].r) : 0;
    return fromLeft + fromRight;
}
void dfs(int u, int f){ sz[u] = 1; dep[u] = dep[f]+1; for(auto v:vc[u]){ dfs(v, u); sz[u] += sz[v]; if(sz[v] > sz[zson[u]]) zson[u] = v; } }
void add(int u, int val){ modify(dep[u],1,n,root[w[u]],val); for(auto v:vc[u]) add(v, val); }
void cal(int u, int lca){ int tar = w[lca]*2 - w[u]; int depl = dep[lca]+1; int depr = dep[lca]+(k-(dep[u]-dep[lca])); depr = min(depr, n); if(depl <= depr && 0 <= tar && tar <= n) ans += query(depl,depr,1,n,root[tar]); for(auto v:vc[u]) cal(v, lca); }
// Small-to-large ("DSU on tree") driver for the subtree of u.
// op != 0: on return, every node of u's subtree remains inserted in the trees.
// op == 0: on return, everything this call inserted has been removed again.
// Pairs are counted (via cal) before each light subtree is inserted, so each
// counted pair has its endpoints in two different child subtrees of u. u itself
// is inserted last and is excluded by cal's depth window (depth > dep[u]).
void dsu(int u, int op) {
    const int heavy = zson[u];

    // Light children first; each one cleans up after itself.
    for (int v : vc[u]) {
        if (v == heavy) continue;
        dsu(v, 0);
    }
    // Heavy child last, keeping its data in the structure.
    if (heavy) dsu(heavy, 1);

    // Merge the light subtrees one by one: count against what is stored, then insert.
    for (int v : vc[u]) {
        if (v == heavy) continue;
        cal(v, u);
        add(v, 1);
    }

    // Finally insert u itself.
    modify(dep[u], 1, n, root[w[u]], 1);

    if (op == 0) add(u, -1);
}
// Reads one test: n and k, then w[1..n], then the parent of each node 2..n.
// Builds child lists, runs the two tree passes, and prints 2 * ans.
void solve() {
    cin >> n >> k;
    for (int i = 1; i <= n; ++i) cin >> w[i];

    for (int child = 2; child <= n; ++child) {
        int parent;
        cin >> parent;
        vc[parent].push_back(child);
    }

    dfs(1, 0);
    dsu(1, 1);

    // NOTE(review): ans appears to count each unordered pair once (when the
    // later of its two subtrees is merged); doubling presumably reports
    // ordered pairs -- confirm against the problem statement.
    cout << ans * 2 << "\n";
}
// Entry point: fast I/O setup, then run the (single) test case.
int main() {
    std::ios::sync_with_stdio(false);
    // With sync disabled, a still-tied cin flushes cout before every
    // extraction; untie it so bulk input is not slowed by pointless flushes.
    std::cin.tie(nullptr);

    int ce = 1;  // number of test cases
    while (ce--) solve();
    return 0;
}
|