Official

G - Spanning Tree Editorial by en_translator


First, let \(G\) be the graph obtained by adding an edge between every pair \((i, j)\) such that \(A_{i, j} = 1\).
If \(G\) contains a cycle, then it can never be extended to a tree, so the answer is \(0\).
Otherwise, one can compress each connected component of \(G\) into a single vertex, which leaves a new problem of the same kind: adding edges between these compressed vertices to obtain a tree.

Kirchhoff’s theorem

Kirchhoff’s matrix tree theorem states that every cofactor of the Laplacian matrix of an undirected (multi)graph \(G\) is equal to the number of spanning trees of \(G\).

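This problem is therefore equivalent to counting the spanning trees of the undirected multigraph whose vertices are the compressed connected components and whose edges are the pairs with \(A_{i,j}=-1\). A pair whose endpoints lie in the same component is dropped, since it would close a cycle. Hence one can solve the problem by building the Laplacian matrix of this multigraph, removing its last row and last column, and computing the determinant modulo \(10^9+7\). (The sample code keeps the matrix at size \(N\) by attaching every non-representative vertex to its component's representative with a single edge. Such a pendant edge belongs to every spanning tree, so the count does not change.)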
The total time complexity is \(O(N^3)\).

Sample Code (C++)

#include <iostream>
#include <vector>
#include <valarray>
#include <atcoder/modint>
#include <atcoder/dsu>
using namespace std;
using Modint = atcoder::modint1000000007;

int main(){
    int N;
    cin >> N;
    vector A(N, vector<int>(N));
    for(auto& v : A) for(int& i : v) cin >> i;
    atcoder::dsu uf(N);
    for(int i = 0; i < N; i++) for(int j = i + 1; j < N; j++) if(A[i][j] == 1){
        if(uf.same(i, j)) return puts("0") & 0;
        uf.merge(i, j);
    }
    vector m(N, valarray(Modint(0), N));
    for(int i = 0; i < N; i++) for(int j = i + 1; j < N; j++) if(A[i][j] == -1){
        const int x = uf.leader(i), y = uf.leader(j);
        m[x][x]++; m[y][y]++; m[x][y]--; m[y][x]--;
    }
    for(int i = 0; i < N; i++){
        const int x = uf.leader(i), y = i;
        m[x][x]++; m[y][y]++; m[x][y]--; m[y][x]--;
    }

    N--;
    Modint ans = 1;
    for(int i = 0; i < N; i++){
        if(m[i][i].val() == 0){
            for(int j = i + 1; j < N; j++) if(m[j][i].val()){
                swap(m[i], m[j]);
                ans *= -1;
                break;
            }
            if(m[i][i].val() == 0) return puts("0") & 0;
        }
        ans *= m[i][i];
        m[i] *= m[i][i].inv();
        for(int j = i + 1; j < N; j++) m[j] -= m[i] * m[j][i];
    }
    cout << ans.val() << endl;
}

Sample Code (Python)

class UnionFind:
    """Disjoint-set forest using union by size and path compression.

    ``data[v]`` stores ``-(size of the set)`` when ``v`` is a root, and the
    parent of ``v`` otherwise.
    """

    def __init__(self, size):
        self.data = [-1] * size

    def root(self, x):
        """Return the representative of x's set.

        Every node on the path from x is re-pointed directly at the root.
        """
        leader = x
        while self.data[leader] >= 0:
            leader = self.data[leader]
        # Second pass: compress the path walked above.
        while x != leader:
            nxt = self.data[x]
            self.data[x] = leader
            x = nxt
        return leader

    def unite(self, x, y):
        """Merge the sets containing x and y.

        Returns False if they were already in the same set, True otherwise.
        """
        rx = self.root(x)
        ry = self.root(y)
        if rx == ry:
            return False
        # Sizes are stored negated, so a smaller value means a larger set.
        if self.data[ry] < self.data[rx]:
            rx, ry = ry, rx
        self.data[rx] += self.data[ry]
        self.data[ry] = rx
        return True

    def size(self, x):
        """Return the number of elements in the set containing x."""
        return -self.data[self.root(x)]

MOD = 1000000007


def count_trees(n, a):
    """Count the trees on vertices 0..n-1 allowed by ``a``, modulo MOD.

    ``a[i][j] == 1`` requires edge {i, j}; ``a[i][j] == -1`` allows it; any
    other value means the edge is not counted. Only the strict lower triangle
    (j < i) is read, so ``a`` is assumed to be symmetric.
    """
    # Contract every required edge. A required edge whose endpoints are already
    # connected closes a cycle, so no tree can contain all of them.
    uf = UnionFind(n)
    for i in range(n):
        for j in range(i):
            if a[i][j] == 1 and not uf.unite(i, j):
                return 0

    # Laplacian of the contracted multigraph, indexed by original vertex numbers.
    # add_edge(x, x) is a no-op, so allowed pairs inside one component vanish.
    m = [[0] * n for _ in range(n)]

    def add_edge(x, y):
        m[x][x] += 1
        m[y][y] += 1
        m[x][y] -= 1
        m[y][x] -= 1

    for i in range(n):
        for j in range(i):
            if a[i][j] == -1:
                add_edge(uf.root(i), uf.root(j))

    # Non-root vertices have no edges so far, which would make the matrix singular.
    # Attach each one to its root by a single pendant edge. A pendant edge lies in
    # every spanning tree, so the spanning-tree count is unchanged.
    for i in range(n):
        add_edge(uf.root(i), i)

    # Matrix-tree theorem: drop the last row and column and take the determinant.
    return _det_mod(m, n - 1)


def _det_mod(m, size):
    """Return the determinant of the leading ``size`` x ``size`` submatrix of ``m``, mod MOD.

    Uses Gaussian elimination over the integers mod MOD and modifies ``m`` in
    place. Returns 1 when ``size`` <= 0 (the empty product).
    """
    det = 1
    for i in range(size):
        if m[i][i] % MOD == 0:
            for r in range(i + 1, size):
                if m[r][i] % MOD:
                    m[i], m[r] = m[r], m[i]
                    det = -det  # a row swap flips the sign
                    break
            else:
                return 0  # no usable pivot: determinant is 0 mod MOD
        row_i = m[i]
        pivot = row_i[i] % MOD
        det = det * pivot % MOD
        inv = pow(pivot, MOD - 2, MOD)  # MOD is prime (Fermat inverse)
        for k in range(i, size):
            row_i[k] = row_i[k] * inv % MOD
        for j in range(i + 1, size):
            factor = m[j][i] % MOD
            if factor == 0:
                continue  # nothing to eliminate in this row
            row_j = m[j]
            for k in range(i, size):
                row_j[k] = (row_j[k] - row_i[k] * factor) % MOD
    return det


if __name__ == "__main__":
    N = int(input())
    A = [list(map(int, input().split())) for _ in range(N)]
    print(count_trees(N, A))

posted:
last update: