AOJ2377 ThreeRooks

これで1200+以外は1問以上解いた事になった




概要

X*YのマスにK匹のうさぎがいる。互いに攻撃できないように3個のルークを置く場合の数をmod 10^9+7で求めよ

X,Y<=10^9,K<=10^5



解法

3個同じ所に並んでいる場合

2個並んでいて1個はその範囲にはない

3個が直角に並んでいる場合

を数えると解ける。
1,2個目は普通に求まって, 3個目は気合で平面走査する
平面走査では、区間全体にある値を足す,あるところを0にする,ある範囲の和を求める
が出来ると良いのでよくあるsegment treeを使う(座標圧縮はそんなに面倒でもない)

こういうのはvectorをもらって計算する関数を作っておくと、座標を入れ替えるだけで良くて嬉しみがある

#include <bits/stdc++.h>
using namespace std;
typedef pair<int, int> pii;
typedef long long ll;
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define rep(i,n) for(int i=0;i<n;++i)

const ll MOD=1000000007;
ll ret,p2,p6;
ll X,Y,K,zan;

inline ll c2(ll x){return x*(x-1)%MOD*p2%MOD;}
inline ll c3(ll x){return x*(x-1)%MOD*(x-2)%MOD*p6%MOD;}
inline void md(ll &x,ll y){x+=y;if(x>=MOD)x-=MOD;}

vector<int> xs;
inline int p(int x){return lower_bound(xs.begin(),xs.end(),x)-xs.begin();}
int sz;
struct segtree{
    vector<ll> dat1, dat2;
    void init(int n_){
	for(sz=1;sz<n_;sz*=2);
	dat1=vector<ll>(sz*2,0);
	dat2=vector<ll>(sz*2,0);  
    }
    void add(int a,int b,ll x,int k=0,int l=0,int r=sz){
	if(b<=l||r<=a)return ;
	if(a<=l&&r<=b)md(dat1[k],x);
	else{
	    md(dat2[k],x*(xs[min(b,r)]-xs[max(a,l)])%MOD);
	    add(a,b,x,k*2+1,l,(l+r)/2);
	    add(a,b,x,k*2+2,(l+r)/2,r);
	}
    }

    ll sum(int a,int b,int k=0,int l=0,int r=sz){
	if(b<=l||r<=a)return 0;
	if(a<=l&&r<=b)return (dat1[k]*(xs[r]-xs[l])+dat2[k])%MOD;
	else{
	    ll res=dat1[k]*(xs[min(b,r)]-xs[max(a,l)]);
	    res+=sum(a,b,k*2+1,l,(l+r)/2);
	    res+=sum(a,b,k*2+2,(l+r)/2,r);
	    return res%MOD;
	}
    }
}seg;

ll calc_23(ll h,ll w,vector<pii>& pt){
    ll s=0,my=h;
    sort(pt.begin(),pt.end());
    for(int i=0;i<pt.size();--my){
	int y=pt[i].fi,j=i;
	while(j<pt.size()&&pt[j].fi==y)++j;
	vector<int> v(j-i);
	for(int k=i;k<j;++k)v[k-i]=pt[k].se;
	md(s,(c3(v[0])+c2(v[0])*(zan-v[0]))%MOD);
	for(int k=1;k<v.size();++k){
	    ll t=v[k]-v[k-1]-1;
	    md(s,c3(t)+c2(t)*(zan-t)%MOD);
	}
	ll t=w-1-v.back();
	md(s,c3(t)+c2(t)*(zan-t)%MOD);
	i=j;
    }
    return (s+(my*c3(w)+(zan-w)*c2(w)%MOD*my))%MOD;
}

ll calc_L(ll h,ll w,vector<pii>& pt){
    sort(pt.begin(),pt.end());
    xs.resize(pt.size()*2);
    rep(i,pt.size()){xs[i*2]=pt[i].se;xs[i*2+1]=pt[i].se+1;}
    xs.pb(0);xs.pb(w);
    sort(xs.begin(),xs.end());
    xs.erase(unique(xs.begin(),xs.end()),xs.end());
    seg.init(xs.size());
    int rr=xs.size()-1;

    ll s=0,la=-1;
    for(int i=0;i<pt.size();){
	int y=pt[i].fi,j=i;
	ll g=seg.sum(0,rr);

	seg.add(0,rr,y-la-1);
	ll ng=seg.sum(0,rr);
	md(s,((w-1)*(g+ng-w)%MOD*(y-la-1)%MOD*p2)%MOD);
	while(j<pt.size()&&pt[j].fi==y)++j;
	vector<int> v(j-i);
	for(int k=i;k<j;++k)v[k-i]=pt[k].se;

	if(v[0]>1)md(s,(v[0]-1)*seg.sum(0,p(v[0]))%MOD);

	for(int k=1;k<v.size();++k){
	    ll t=v[k]-v[k-1]-1;
	    if(t>1)md(s,(t-1)*seg.sum(p(v[k-1]+1),p(v[k]))%MOD);
	}
	ll t=w-1-v.back();
	if(t>1)md(s,(t-1)*seg.sum(p(v.back()+1),rr)%MOD);
	seg.add(0,rr,1);
	for(int k=i;k<j;++k){
	    int l=p(pt[k].se);
	    ll x=seg.sum(l,l+1);
	    seg.add(l,l+1,-x);
	}
	i=j;
	la=y;
    }
    if(la!=h-1){
	ll num=h-1-la,t1=seg.sum(0,rr),t2=(t1+w*(num-1))%MOD;
	md(s,(w-1)*(t1+t2)%MOD*num%MOD*p2%MOD);
    }
    return s;
}

int main(){
    p2=(MOD+1)/2;p6=(MOD+1)/6;
    scanf("%lld%lld%lld",&X,&Y,&K);
    vector<pii> pt(K);
    rep(i,K){
	int x,y;
	scanf("%d%d",&x,&y);
	pt[i]=mp(x,y);
    }
    zan=(X*Y-K)%MOD;
    if(zan<3){puts("0");return 0;}
    ret=c3(zan);
    md(ret,MOD-calc_23(X,Y,pt));
    rep(i,K)swap(pt[i].fi,pt[i].se);
    md(ret,MOD-calc_23(Y,X,pt));
    rep(i,K)swap(pt[i].fi,pt[i].se);
    ret+=calc_L(X,Y,pt);
    rep(i,K)pt[i].fi=X-1-pt[i].fi;
    ret+=calc_L(X,Y,pt);
    printf("%d\n",ret%MOD);
    return 0;
}