JOI 春合宿「倉庫番 (Sokoban)」

「これはグリッドグラフの問題だ！」という気持ちで、グラフの問題として頑張って解きます（DFS 木と lowlink を使用）。
コードだけ載せておきます。計算量は O(MN) です。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <utility>
#include <vector>

using namespace std;

// 64-bit integer type used for the answer.
typedef long long ll;

// rep(i, n): loop i = 0 .. n-1.
// NOTE(review): the cast binds only to the first token of the bound, e.g.
// rep(i, a + b) expands to i < (int)a + b; every use visible here passes a
// single term, so this is harmless today.
#define rep(i, n) for(int i = 0; i < (int)n; ++i)

// Direction k moves a cell by (dx[k], dy[k]): 0 = x+1, 1 = x-1, 2 = y+1, 3 = y-1.
const int dx[4] = {1, -1, 0, 0};
const int dy[4] = {0, 0, 1, -1};

// n: number of columns, m: number of rows (main reads m first);
// (sy, sx): the cell that was marked 'X';
// zen: number of vertices in the DFS tree rooted at X (set in main from sz[]).
int n, m, sx, sy, zen;
// Answer accumulated by calc().
ll ret;
// Cell graph, vertex id = y * n + x: g[v][0..deg[v]) are v's free 4-neighbours.
int deg[1000010];
// DFS-tree children of v recorded by pro(): tr[v][0..dt[v]).
int dt[1000010];
int g[1000010][4];
int tr[1000010][4];
// Grid characters; '#' is a wall (main also turns cells not reachable from X into '#').
char ban[1010][1010];
// First holds the reachable-from-X flags filled by pre(); main then clears it
// and calc() reuses it as "set() has already been run for this cell".
bool can[1010][1010];
// DFS visited flags for pro().
bool vis[1000010];
// inv[y][x][d]: calc() has already handled the state "blocked cell (y, x),
// other occupied cell at (y, x) + direction d".
bool inv[1010][1010][4];
// DFS entry/exit timestamps (ord/out) and lowlinks (low), filled by pro().
int ord[1000010], out[1000010], low[1000010];
// Number of vertices in v's DFS subtree, v included.
int sz[1000010];
// DFS-tree parent of v; -1 for the root.
int par[1000010];
// Timestamp counter shared by ord[] and out[].
int now;
// adj[v][i][j]: neighbours g[v][i] and g[v][j] stay connected when v is removed
// (filled by set(); entries with i == j are left false).
bool adj[1000010][4][4];

inline void pre(int y, int x){
    can[y][x] = true;
    rep(i, 4){
	int ny = y + dy[i];
	int nx = x + dx[i];
	if(ny >= 0 && ny < m && nx >= 0 && nx < n && ban[ny][nx] != '#' && !can[ny][nx]) pre(ny, nx);
    }
}

void pro(int v, int p){
    low[v] = ord[v] = now++;
    sz[v] = 1;
    par[v] = p;
    vis[v] = true;
    rep(i, deg[v]) if(g[v][i] != p){
	if(!vis[g[v][i]]){
	    tr[v][dt[v]++] = g[v][i];
	    pro(g[v][i], v);
	    sz[v] += sz[g[v][i]];
	    low[v] = min(low[v], low[g[v][i]]);
	}else{
	    low[v] = min(low[v], ord[g[v][i]]);
	}
    }
    out[v] = now++;
}

// Ancestry test on the DFS tree built by pro(): true when p is an ancestor of
// q, with p == q counting as an ancestor. It holds exactly when q's
// [ord, out] interval nests inside p's.
inline bool dec(int p, int q){
    if(ord[q] < ord[p]) return false;  // q was entered before p, so it cannot lie below p
    return out[p] >= out[q];
}

void set(int v){ //remove v
    int ap[4] = {-1};
    rep(i, deg[v]){
	int a = g[v][i];	
	if(dec(v, a)){
	    rep(j, dt[v]){
		if(dec(tr[v][j], a)){
		    ap[i] = tr[v][j];
		    break;
		}
	    }
	}
    }

    rep(i, deg[v]){
	int a = g[v][i];
	rep(j, deg[v]) if(i != j){
	    int b = g[v][j];
	    if(!(dec(v, a) || dec(v, b))) adj[v][i][j] = adj[v][j][i] = true;
	    else if(dec(v, a) && !dec(v, b)){
		adj[v][i][j] = adj[v][j][i] = (low[ap[i]] < ord[v]);
	    }else if(dec(v, a) && dec(v, b)){
		if(ap[i] == ap[j] || (low[ap[i]] < ord[v] && low[ap[j]] < ord[v])){
		    adj[v][i][j] = adj[v][j][i] = true;
		}
	    }
	}
    }
}

void calc(int y, int x, int dir){
    if(inv[y][x][dir]) return ;
    //printf("y : %d, x : %d, dir : %d\n", y, x, dir);
    inv[y][x][dir] = true;
    int v = y * n + x;
    if(!can[y][x]) set(v);
    can[y][x] = true;
    int px = x + dx[dir], py = y + dy[dir];
    int nv = py * n + px;
    int whi = -1;
    rep(i, deg[v]){
	if(g[v][i] == nv){
	    whi = i;
	    break;
	}
    }

    int kou[4];
    int piv = 0;
    //printf("whi : %d\n", whi);
    ll mi = 0;
    bool f = false;
    //printf("y : %d, x : %d, py : %d, px : %d\n", y, x, py, px);
    rep(i, deg[v]) {
	if((i == whi) || adj[v][whi][i]){
	    int dev = g[v][i];
	    if(dec(g[v][i], v)) f = true;
	    //printf("(%d %d) ", dev / n, dev % n);
	    kou[piv++] = g[v][i];
	}
    }
    //puts("");

    ll tmp = 0;
    if(par[v] != -1){
	rep(i, piv){
	    if(par[v] == kou[i]){
		tmp += zen - sz[v] - 1;
	    }
	}
    }

    rep(i, dt[v]){
	rep(j, piv){
	    if(tr[v][i] == kou[j]) tmp += sz[tr[v][i]];
	}
    }
    
    if(y != sy || x != sx) ret += tmp;
    //printf("y : %d, x : %d, dir : %d, add : %lld\n", y, x, dir, tmp);

    int d[4] = {-1};

    rep(i, piv){
	int ny = kou[i] / n, nx = kou[i] % n;
	int ndir = -1;
	int dify = ny - y, difx = nx - x;
	if(difx == 1 && dify == 0) ndir = 0;
	else if(difx == -1 && dify == 0) ndir = 1;
	else if(difx == 0 && dify == 1) ndir = 2;
	else ndir = 3;
	inv[y][x][ndir] = true;
	d[i] = ndir;
    }

    rep(i, piv){
	int ny = kou[i] / n, nx = kou[i] % n;
	py = ny + dy[d[i]];
	px = nx + dx[d[i]];
	if(px < 0 || py < 0 || px >= n || py >= m || ban[py][px] == '#') continue;
	//printf("(py, px) : (%d, %d), (ny, nx) : (%d, %d), ndir : %d\n", py, px, ny, nx, ndir);
	calc(ny, nx, d[i]);
    }
}

int main(){
    scanf("%d %d", &m, &n);
    rep(i, m){
	scanf("%s", ban[i]);
	rep(j, n) if(ban[i][j] == 'X'){
	    sy = i;
	    sx = j;
	    ban[i][j] = '.';
	}
    }
    pre(sy, sx);
    rep(i, m) rep(j, n) if(!can[i][j]) ban[i][j] = '#';
    memset(can, 0, sizeof(can));
    rep(i, m) rep(j, n){
	int y = i, x = j, v = i * n + j;
	if(ban[i][j] == '#') continue;
	rep(k, 4){
	    int ny = y + dy[k], nx = x + dx[k], nv = ny * n + nx;
	    if(ny >= 0 && ny < m && nx >= 0 && nx < n && ban[ny][nx] != '#'){
		g[v][deg[v]++] = nv;
	    }
	}
    }
    pro(n * sy + sx, -1);
    zen = sz[n * sy + sx];
    rep(i, 4){
	int nx = sx + dx[i], ny = sy + dy[i];
	if(nx < 0 || ny < 0 || nx >= n || ny >= m || ban[ny][nx] == '#') continue;
	int px = nx + dx[i], py = ny + dy[i];
	if(px < 0 || py < 0 || px >= n || py >= m || ban[py][px] == '#') continue;
	calc(ny, nx, i);
    }
    printf("%lld\n", ret);
    return 0;
}