Skip to content

F2

C++

template<int bitsize=60>
struct F2 {
    /*
     * Ax = b となる x の解を求める
    */
    std::vector<std::bitset<bitsize>> A;
    std::vector<int> b;

    // 方程式を追加する
    void add_equation(std::bitset<bitsize> Ai, bool bi) {
        A.push_back(Ai);
        b.push_back(1 * bi);
    }

    // 方程式を解く
    // 解が存在する (true) か否か (false) を返す
    bool solve() {
        int n = A.size(), m = bitsize, row = 0;
        for(int j = 0; j < m; j++) {
            for(int i = row; i < n; i++) {
                if(A[i][j]) {
                    std::swap(A[row], A[i]);
                    std::swap(b[row], b[i]);
                    for(int k = 0; k < n; k++) {
                        if(k == row) continue;
                        if(A[k][j]) {
                            A[k] ^= A[row];
                            b[k] ^= b[row];
                        }
                    }
                    row++;
                    break;
                }
            }
        }

        // 解の存在判定
        for(int i = n - 1; i >= 0; i--) {
            if(A[i].any()) break;
            if(b[i] != 0) return false;
        }
        return true;
    }

    int rank() {
        int r = A.size(), n = r;
        for(int i = n - 1; i >= 0; i--) {
            if(A[i].any()) break;
            r--;
        }
        return r;
    }

    // num の数だけ解を列挙する (0 <= num <= min(2^20, 2^{n-rank}))
    std::vector<std::bitset<bitsize>> enumerate_solutions(int num) {
        // 1 が立ってる場所を求める
        std::vector<int> one;
        int n = A.size(), m = bitsize, pos = 0;
        for(int i = 0; i < n; i++) {
            while(pos < m and !A[i][pos]) {
                pos++;
            }
            if(pos < m) one.push_back(pos);
        }
        one.push_back(m);

        // 自由に決められる x を求める
        std::vector<int> is_free(m, 0);
        int r = m; // rank
        pos = 0;
        for(int i : one) {
            while(pos < i) {
                is_free[pos++] = 1;
                r--;
            }
            pos++;
        }

        std::vector<std::bitset<bitsize>> solutions;
        int p = std::min(20, r), i = r - 1;
        for(int s = 0; s < std::min(num, 1 << p); s++) {
            std::bitset<bitsize> solution;
            for(int j = m - 1; j >= 0; j--) {
                if(is_free[j]) {
                    if(s >> j & 1) {
                        solution.set(j);
                    }
                } else {
                    int sum = 0;
                    for(int k = j + 1; k < m; k++) {
                        sum ^= A[i][k] * solution[k];
                    }
                    if(sum ^ b[i]) solution.set(j);
                    i--;
                }
            }
            solutions.push_back(solution);
        }

        return solutions;
    }
};

Examples