【线段树维护区间两两乘积和】2021 杭电多校第九场 1006 Guess the weight
2021-08-18 21:11:00 # ACM

题链

题目分析

假设拿到的第一张牌的费用是 i,牌库中费用 <i 的有 x 张,费用 >i 大的有 y 张;
如果 x<=y,那就需要猜 “greater”;
如果 x>y,那就需要猜 “smaller”;

开一个权值线段树ai 表示费用为 i 的有多少张,需要确定一个中值 mid,当 i<=mid 时猜 “greater”, i>mid 时猜 “smaller”;

考虑抽到第一张牌费用为 i 时,此时抽这张牌的概率 P=aisumsum 为总牌数:
  当 i<=mid 时,在剩下的牌中选择到费用 j>i 的概率是 j>iajsum1,则 P=aij>iajsum(sum1)
  当 i>mid  时,在剩下的牌中选择到费用 j<i 的概率是 j<iajsum1,则 P=aij<iajsum(sum1)

那么对于总概率 P=[i<=midaij>iaj+i>midaij<iaj]/[sum(sum1)]

对于该公式的 aiaj并不好维护,所以考虑相反的方式计算 P 值;

  当 i<=mid 时,在剩下的牌中选择费用 j<=i
  当 i>mid  时,在剩下的牌中选择费用 j>=i

于是计算的总概率 P=1[i<=midai[(j<=iaj)1]+i>midai[(j>=iaj)1]]/[sum(sum1)]
(这里的 1 是指选取第一张牌所以 1 );

其中公式化简后有 aiajai,就是需要维护区间和以及区间两两乘积和

如何维护

假设某根节点 rt 的左右孩子节点 ls,rs,已经维护好了他们的区间和 sum 以及区间两两乘积和 mulsum,那么根节点的信息如此:ls.sum+rs.sum,ls.mulsum+rs.mulsum+ls.sumrs.sum

就如 ls=a1+a2,a12+a22+a1a2rs=a3,a32
rt.mulsum=a12+a22+a32+a1a2+(a1+a2)(a3)

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include <bits/stdc++.h>
using namespace std;
#define LL long long
#define Pair pair<LL ,LL >
#define ls rt<<1
#define rs rt<<1|1
#define PI acos(-1.0)
#define eps 1e-13
#define mod 998244353
#define MAXN 200001
#define MS 200005

int n,m;
struct node{
LL sum;
LL mulsum;
}p[MAXN<<2];

void build(int l,int r,int rt){
p[rt] = {0,0};
if(l == r) return;
int m = l+r>>1;
build(l,m,ls);
build(m+1,r,rs);
}

void push_up(int rt){
p[rt].sum = p[ls].sum + p[rs].sum;
p[rt].mulsum = p[ls].mulsum + p[rs].mulsum + p[ls].sum*p[rs].sum;
}

void modify(int pos,int l,int r,int rt,int val){
if(l == r){
p[rt].sum += val;
p[rt].mulsum = p[rt].sum*p[rt].sum;
return;
}
int m = l+r>>1;
if(m >= pos) modify(pos,l,m,ls,val);
else modify(pos,m+1,r,rs,val);
push_up(rt);
}

int get_mid(int l,int r,int rt,int kth){
if(l == r) return l;
int m = l+r>>1;
if(kth <= p[ls].sum) return get_mid(l,m,ls,kth);
else return get_mid(m+1,r,rs,kth-p[ls].sum);
}

Pair cal(Pair t1, Pair t2){
return {t1.first + t2.first
,t1.second + t2.second + t1.first*t2.first};
}

Pair query(int L,int R,int l,int r,int rt){
if(L <= l && r <= R){
return {p[rt].sum, p[rt].mulsum};
}
Pair ans = {0,0};
int m = l+r>>1;
if(m >= L) ans = cal(ans,query(L,R,l,m,ls));
if(m < R) ans = cal(ans,query(L,R,m+1,r,rs));
return ans;
}

inline void solve(){
cin >> n >> m; build(1,MAXN,1);
for(int i=1;i<=n;i++){
int x; cin >> x;
modify(x,1,MAXN,1,1);
}
while(m--){
int op,x,w;
cin >> op;
if(op == 1){
cin >> x >> w;
modify(w,1,MAXN,1,x);
}
else{
// get_mid
int mid = get_mid(1,MAXN,1,p[1].sum/2);
int lsum = query(1,mid,1,MAXN,1).first;
int val = query(mid,mid,1,MAXN,1).first;
if(lsum-val > p[1].sum-lsum) mid--;
// get_ans
Pair tl = query(1,mid,1,MAXN,1);
Pair tr = query(mid+1,MAXN,1,MAXN,1);
LL t1 = tl.second - tl.first;
LL t2 = tr.second - tr.first;
LL a = t1 + t2;
LL b = p[1].sum*(p[1].sum-1);
a = b-a;
LL t = __gcd(a,b);
a /= t, b /= t;
cout << a << "/" << b << "\n";
}
}
}

int main() {
ios::sync_with_stdio(false);
// srand((unsigned)time(0));
int ce = 1;
cin >> ce;
// scanf("%d",&ce);
while(ce--) solve();

return 0;
}
/*


*/
Prev
2021-08-18 21:11:00 # ACM
Next