MatrixClass

矩阵类(所有运算对 MOD 取模),使用朴素矩阵乘法算法,有一定常数优化:最内层循环按 4 路展开;当 N 超过 1024 时改用 Strassen 分治(补零到 2 的幂后分块递归)。

// Modulus used by Matrix::Maintain; stored entries are reduced into [0, MOD).
const int MOD = 1000000007; // 1e9 + 7

// Square N x N matrix whose arithmetic is performed modulo MOD.
//
// Entries live row-major in `arr`: element (i, j) is arr[i * N + j].
// operator* uses the schoolbook O(N^3) algorithm with the innermost loop
// unrolled by 4. When N exceeds StrassenThreshold, both operands are
// zero-padded to a power-of-two size and split into quarters for one Strassen
// step (7 sub-products instead of 8); the sub-products go back through
// operator*, so the recursion continues until the blocks fit the threshold.
//
// Requirements on T (not checked): a signed integer type wide enough to hold
// (MOD - 1) * (MOD - 1) + MOD, e.g. long long. Maintain relies on `x < 0` to
// fix up negative remainders after subtraction, and a 32-bit T overflows in
// the products computed by operator*.
//
// Copy, move, assignment and destruction are compiler-generated (rule of
// zero), so temporaries returned by the operators are moved, not copied.
//
// StrassenThreshold: largest N multiplied directly (must be >= 1 so that the
// recursion always reaches the direct case).
template <typename T, int StrassenThreshold = 1024>
class Matrix
{
    static_assert(StrassenThreshold >= 1, "StrassenThreshold must be at least 1");

public:
    std::vector<T> arr; // N * N entries, row-major
    int N, NN;          // dimension, and N * N
    T e, z;             // multiplicative identity, additive identity

    // n:    dimension.
    // one:  multiplicative identity, written to the diagonal by GetE.
    // zero: additive identity, written by GetZero / GetE and used as padding.
    // Entries start value-initialized (T()), i.e. 0 for arithmetic T.
    // NOTE(review): not `explicit`, so an int silently converts to an n x n
    // matrix (e.g. `m * 3`); kept implicit for source compatibility.
    Matrix(int n, T one = T(1), T zero = T(0))
        : arr(n * n), N(n), NN(n * n), e(one), z(zero)
    {
    }

    // Copies *other; other must not be null.
    Matrix(const Matrix *other) : Matrix(*other)
    {
    }

    // Reduces x into [0, MOD) (x % MOD keeps the sign of x, hence the fix-up).
    static void Maintain(T &x)
    {
        x = x % MOD;
        if (x < 0)
            x += MOD;
    }

    // Element (i, j). No bounds checking.
    T &operator()(int i, int j)
    {
        return arr[i * N + j];
    }
    const T &operator()(int i, int j) const
    {
        return arr[i * N + j];
    }

    // Entry-wise sum modulo MOD. Caller guarantees other.N == N.
    Matrix operator+(const Matrix &other) const
    {
        Matrix ret(N, e, z);
        for (int i = 0; i < NN; ++i)
        {
            ret.arr[i] = arr[i] + other.arr[i];
            Maintain(ret.arr[i]);
        }
        return ret;
    }

    // Entry-wise difference modulo MOD, result in [0, MOD). Caller guarantees other.N == N.
    Matrix operator-(const Matrix &other) const
    {
        Matrix ret(N, e, z);
        for (int i = 0; i < NN; ++i)
        {
            ret.arr[i] = arr[i] - other.arr[i];
            Maintain(ret.arr[i]);
        }
        return ret;
    }

    // Matrix product modulo MOD. Caller guarantees other.N == N.
    // Allocation may throw std::bad_alloc; all storage is owned by std::vector,
    // so nothing leaks if it does.
    Matrix operator*(const Matrix &other) const
    {
        if (N <= StrassenThreshold)
            return NaiveMultiply(other);
        return StrassenMultiply(other);
    }

    // Returns (*this)^k modulo MOD by binary exponentiation.
    // k <= 0 yields the identity matrix built by GetE.
    Matrix QuickPow(long long k) const
    {
        Matrix ret(N, e, z);
        ret.GetE();
        Matrix base(*this);
        while (k > 0)
        {
            if (k & 1)
                ret = ret * base;
            k >>= 1;
            // The square after the highest set bit would never be used.
            if (k > 0)
                base = base * base;
        }
        return ret;
    }

    // Sets every entry to the additive identity z.
    void GetZero()
    {
        for (T &x : arr)
            x = z;
    }

    // Turns *this into the identity: e on the diagonal, z elsewhere.
    void GetE()
    {
        GetZero();
        for (int i = 0; i < N; ++i)
            arr[i * N + i] = e;
    }

private:
    // Schoolbook product in i-k-j order (row-wise access of both operands).
    // The j loop is unrolled by 4; `j + 3 < N` means a full group of four remains.
    Matrix NaiveMultiply(const Matrix &other) const
    {
        Matrix ret(N, e, z);
        for (int i = 0; i < N; ++i)
        {
            T *out = ret.arr.data() + i * N;
            for (int k = 0; k < N; ++k)
            {
                const T a = arr[i * N + k];
                const T *row = other.arr.data() + k * N;
                int j = 0;
                for (; j + 3 < N; j += 4)
                {
                    out[j] = out[j] + a * row[j];
                    Maintain(out[j]);
                    out[j + 1] = out[j + 1] + a * row[j + 1];
                    Maintain(out[j + 1]);
                    out[j + 2] = out[j + 2] + a * row[j + 2];
                    Maintain(out[j + 2]);
                    out[j + 3] = out[j + 3] + a * row[j + 3];
                    Maintain(out[j + 3]);
                }
                for (; j < N; ++j)
                {
                    out[j] = out[j] + a * row[j];
                    Maintain(out[j]);
                }
            }
        }
        return ret;
    }

    // One Strassen step: pad both operands to a power-of-two size, split them
    // into quarters, form M1..M7 (each via operator*, so they recurse while
    // still above the threshold), reassemble, and crop back to N x N.
    Matrix StrassenMultiply(const Matrix &other) const
    {
        const Matrix A = FillToFit();
        const Matrix B = other.FillToFit();

        const Matrix A11 = A.Quarter(0);
        const Matrix A12 = A.Quarter(1);
        const Matrix A21 = A.Quarter(2);
        const Matrix A22 = A.Quarter(3);

        const Matrix B11 = B.Quarter(0);
        const Matrix B12 = B.Quarter(1);
        const Matrix B21 = B.Quarter(2);
        const Matrix B22 = B.Quarter(3);

        const Matrix M1 = (A11 + A22) * (B11 + B22);
        const Matrix M2 = (A21 + A22) * B11;
        const Matrix M3 = A11 * (B12 - B22);
        const Matrix M4 = A22 * (B21 - B11);
        const Matrix M5 = (A11 + A12) * B22;
        const Matrix M6 = (A21 - A11) * (B11 + B12);
        const Matrix M7 = (A12 - A22) * (B21 + B22);

        const Matrix C = Merge(M1 + M4 - M5 + M7, M3 + M5, M2 + M4, M1 - M2 + M3 + M6);

        Matrix ret(N, e, z);
        for (int i = 0; i < N; ++i)
            for (int j = 0; j < N; ++j)
                ret.arr[i * N + j] = C.arr[i * C.N + j];
        return ret;
    }

    // Returns one (N/2) x (N/2) quarter; N must be even (a power of two here).
    // pos: 0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right.
    Matrix Quarter(int pos) const
    {
        const int half = N / 2;
        const int rowOffset = (pos & 2) ? half : 0;
        const int colOffset = (pos & 1) ? half : 0;
        Matrix ret(half, e, z);
        for (int i = 0; i < half; ++i)
            for (int j = 0; j < half; ++j)
                ret.arr[i * half + j] = arr[(i + rowOffset) * N + colOffset + j];
        return ret;
    }

    // Copy of *this padded with z to the smallest power-of-two size >= N.
    Matrix FillToFit() const
    {
        int n = 1;
        while (n < N)
            n <<= 1;
        Matrix ret(n, e, z);
        ret.GetZero();
        for (int i = 0; i < N; ++i)
            for (int j = 0; j < N; ++j)
                ret.arr[i * n + j] = arr[i * N + j];
        return ret;
    }

    // Assembles a 2n x 2n matrix from four n x n blocks (top-left, top-right,
    // bottom-left, bottom-right; all the same size) and reduces every entry.
    static Matrix Merge(const Matrix &LU, const Matrix &RU, const Matrix &LD, const Matrix &RD)
    {
        const int n = LU.N;
        const int w = 2 * n;
        Matrix ret(w, LU.e, LU.z);
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
            {
                const int src = i * n + j;
                ret.arr[i * w + j] = LU.arr[src];
                ret.arr[i * w + n + j] = RU.arr[src];
                ret.arr[(i + n) * w + j] = LD.arr[src];
                ret.arr[(i + n) * w + n + j] = RD.arr[src];
            }
        for (T &x : ret.arr)
            Maintain(x);
        return ret;
    }
};