E2. PermuTree (hard version)
提示一:数值域无关性,不管给一个子树分配哪些数字,在子树上经过排序后可得到唯一的最大值
提示二:只关注当前节点的数值应当将哪些子树分隔开
根据根节点的值可以将子树划分为两个子集,设节点个数为x,child-x;当前节点得分为x*(child-x)
式子中当x越接近child/2值越大,所以我们可以考虑跑一个背包尽量使子集1占满child/2个节点
由于节点数目过多,背包容量过大,根据分析子树大小最多出现<=2000种数值,那么一定有大量重复,可以考虑多重背包
最后注意一个剪枝:直链带一个侧枝的情况
/*input
4
1 1 1
*/
#include<iostream>
#include<vector>
#include<algorithm>
using namespace std;
const int N=1e6+10;
int sum[N];
std::vector<std::vector<int>> ChildList(N);
std::vector<std::vector<int>> SumList(N);
void SumBinary(int n){
for(int i=1;i<=n;i++){
std::vector<int> tmp;
swap(tmp,SumList[i]);
sort(tmp.begin(),tmp.end());
int sz=tmp.size();
for(int k=0;k<sz;k++){
int j=k;
while(j<sz-1&&tmp[j+1]==tmp[k]) j++;
int sum_k=j-k+1,Bin=1;
while(sum_k>=Bin){
sum_k-=Bin;
SumList[i].emplace_back(Bin*tmp[k]);
Bin<<=1;
}
Bin=1;
while(Bin<=sum_k){
if(Bin&sum_k) SumList[i]
.emplace_back(Bin*tmp[k]);
Bin<<=1;
}
k=j;
}
}
}
void dfs(int root){
sum[root]=1;
int sz=ChildList[root].size();
for(int i=0;i<sz;i++){
int node=ChildList[root][i];
dfs(node);
sum[root]+=sum[node];
SumList[root].emplace_back(sum[node]);
}
}
long long Bag(int n){
long long ans=0;
vector<int> dp(n,0);
for(int i=1;i<=n;i++){
int BagSz=sum[i]/2;
**剪枝在这**
if(BagSz&&SumList[i].back()>=BagSz){
ans+=(sum[i]-1LL-SumList[i].back())*
SumList[i].back();
continue;
}
int sz=SumList[i].size();
for(int i=1;i<=BagSz;i++)
dp[i]=0;
for(int j=0;j<sz;j++)
for(int k=BagSz;k>=SumList[i][j];k--)
dp[k]=max(dp[k],dp[k-SumList[i][j]]
+SumList[i][j]);
ans+=(sum[i]-1LL-dp[BagSz])*dp[BagSz];
}
return ans;
}
int read(){
int ans=0;
char c=getchar();
while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9'){
ans=ans*10+c-'0';
c=getchar();
}
return ans;
}
void solve(){
int n;
long long ans=0;
cin>>n;
for(int i=2;i<=n;i++){
int x;cin>>x;
ChildList[x].emplace_back(i);
}
dfs(1);
SumBinary(n);
ans=Bag(n);
cout<<ans<<endl;
}
int main(){
cin.tie(0);
ios::sync_with_stdio(false);
int t;
t=1;
// cin>>t;
while(t--)
solve();
}
位或版本
/*input
4
1 1 1
*/
#include<iostream>
#include<vector>
#include<bitset>
#include<algorithm>
using namespace std;
const int N=1e6+10;
int sum[N];
std::vector<std::vector<int>> ChildList(N);
std::vector<std::vector<int>> SumList(N);
void SumBinary(int n){
for(int i=1;i<=n;i++){
std::vector<int> tmp;
swap(tmp,SumList[i]);
sort(tmp.begin(),tmp.end());
int sz=tmp.size();
for(int k=0;k<sz;k++){
int j=k;
while(j<sz-1&&tmp[j+1]==tmp[k]) j++;
int sum_k=j-k+1,Bin=1;
while(sum_k>=Bin){
sum_k-=Bin;
SumList[i].push_back(Bin*tmp[k]);
Bin<<=1;
}
if(sum_k) SumList[i].push_back(sum_k*tmp[k]);
k=j;
}
}
}
void dfs(const int &root){
sum[root]=1;
if(ChildList[root].empty()) return ;
for(const auto &node:ChildList[root]){
dfs(node);
sum[root]+=sum[node];
SumList[root].push_back(sum[node]);
}
}
template<int len=1>
int BitSet(int root){
if(sum[root]/2>=len){
return BitSet<min(len*2,N)>(root);
}
if(SumList[root].empty()) return 0;
if(SumList[root].back()>=sum[root]/2)
return SumList[root].back();
bitset<len> b(1);
int sz=SumList[root].size();
for(int i=0;i<sz;i++){
b|=b<<SumList[root][i];
}
for(int i=sum[root]/2;i>=0;i--){
if(b[i]) return i;
}
return 0;
}
long long Bag(int n){
long long ans=0;
for(int i=1;i<=n;i++){
int bag=BitSet<1>(i);
ans+=(sum[i]-1LL-bag)*bag;
}
return ans;
}
int read(){
int ans=0;
char c=getchar();
while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9'){
ans=ans*10+c-'0';
c=getchar();
}
return ans;
}
void solve(){
int n;
long long ans=0;
n=read();
for(int i=2;i<=n;i++){
int x;x=read();
ChildList[x].push_back(i);
}
dfs(1);
SumBinary(n);
ans=Bag(n);
printf("%lld\n",ans);
}
int main(){
int t;
t=1;
while(t--)
solve();
}