http://acm.hdu.edu.cn/showproblem.php?pid=4635
我们把缩点后的新图(实际编码中可以不建新图 只是为了概念上好理解)中的每一个点都赋一个值
表示是由多少个点缩成的 我们需要找所有端点 也可能出发点(只有出度) 也可能是结束点 (只有入度)
这个端点和外界(其它所有点)的联通性是单向的(只入或只出) 也可能没有联通
在保持这个端点与外界的单向联通性的情况下 任意加边
所以 当端点的值越小(包含点越少) 结果越优
代码:
#include<iostream> #include<cstdio> #include<algorithm> #include<string> #include<cstring> #include<cmath> #include<set> #include<vector> #include<list> #include<stack> #include<queue> using namespace std; typedef pair<int,int> pp; typedef long long ll; const int N=100005; const int M=100005; int head[N],I; struct node { int j,next; }edge[M]; int low[N],dfn[N],f[N],deep; bool in[N],visited[N]; stack<int>st; pp p[M]; void add(int i,int j) { edge[I].j=j; edge[I].next=head[i]; head[i]=I++; } bool ok(vector<int>& vt) { for(unsigned int i=0;i<vt.size();++i) { int x=vt[i]; for(int t=head[x];t!=-1;t=edge[t].next) { int y=edge[t].j; if(f[x]!=f[y]) return false; } } return true; } void tarjan(int x,int &M) { visited[x]=true; in[x]=true; st.push(x); low[x]=dfn[x]=deep++; for(int t=head[x];t!=-1;t=edge[t].next) { int j=edge[t].j; if(visited[j]==false) { tarjan(j,M); low[x]=min(low[x],low[j]); }else if(in[j]==true) { low[x]=min(low[x],dfn[j]); } } if(low[x]==dfn[x]) { vector<int>vt; int tmp=1; while(st.top()!=x) { int k=st.top(); st.pop(); vt.push_back(k); in[k]=false; f[k]=x; ++tmp; } int k=st.top(); st.pop(); vt.push_back(k); in[k]=false; f[k]=x; if(ok(vt)) { M=min(M,tmp); } } } void init(int n,int m) { memset(head,-1,sizeof(head)); I=0; for(int i=0;i<m;++i) add(p[i].first,p[i].second); } int solve(int n,int m) { init(n,m); while(!st.empty()) st.pop(); for(int i=1;i<=n;++i) {f[i]=i;} memset(in,false,sizeof(in)); memset(visited,false,sizeof(visited)); deep=0; int k=n+1; for(int i=1;i<=n;++i) if(!visited[i]) tarjan(i,k); return k; } int main() { //freopen("data.in","r",stdin); int T; scanf("%d",&T); for(int ca=1;ca<=T;++ca) { printf("Case %d: ",ca); int n,m; scanf("%d %d",&n,&m); for(int i=0;i<m;++i) scanf("%d %d",&p[i].first,&p[i].second); int k=solve(n,m); for(int i=0;i<m;++i) swap(p[i].first,p[i].second); k=min(solve(n,m),k); if(k==n) {cout<<"-1"<<endl;continue;} ll ans=0; ans=(ll)(n)*(ll)(n-1); ans-=m; ans-=(ll)(k)*(ll)(n-k); cout<<ans<<endl; } return 0; }