Submission #2997335


Source Code Expand

Copy
// package other2018.mujinpc2018;

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.math.BigInteger;
import java.util.*;

public class Main {
    // Wall-clock time (ms) captured when Main is initialized; printTime reports elapsed time relative to it.
    static long __startTime = System.currentTimeMillis();

    /**
     * Reads H and W and prints, modulo {@code MOD}, the number of ways to pick one H-bit cell pattern per
     * column such that 1x2 pieces can cover exactly the picked cells.
     *
     * <p>The DP state is a "pattern set": a bitmask of the protrusion patterns that the columns processed so
     * far can leave sticking into the next column. Only states reachable from {@code {0}} (bitmask 1) are
     * kept. The transition counts between them form a matrix, which is raised to the W-th power to process
     * W columns.
     */
    public static void main(String[] args) {
        InputReader in = new InputReader(System.in);
        PrintWriter out = new PrintWriter(System.out);

        H = in.nextInt();
        // The width is only used as a matrix exponent, and Matrix.pow takes a long exponent,
        // so read it as a long instead of truncating it to an int.
        long width = in.nextLong();
        W = (int) width;
        memo = new HashMap<>();

        // Start state: nothing protrudes into the first column, i.e. the set containing only pattern 0.
        patternSets = collectReachablePatternSets(1);
        ptoi = new HashMap<>();
        for (int k : patternSets) {
            ptoi.put(k, ptoi.size());
        }

        long[][] tbl = buildTransitionMatrix();
        int zero = ptoi.get(1);
        long[][] tblW = Matrix.pow(tbl, width);
        long ans = 0;
        for (int k : patternSets) {
            // Bit 0 set: some placement leaves nothing protruding past the last column.
            if ((k & 1) == 1) {
                ans = (ans + tblW[ptoi.get(k)][zero]) % MOD;
            }
        }
        out.println(ans);
        out.flush();
    }

    /** Returns every pattern set reachable from {@code start} through buildTable transitions, start included. */
    private static Set<Integer> collectReachablePatternSets(int start) {
        Set<Integer> reached = new HashSet<>();
        Deque<Integer> pending = new ArrayDeque<>();
        reached.add(start);
        pending.add(start);
        while (!pending.isEmpty()) {
            int current = pending.poll();
            for (int next : buildTable(current)) {
                // Set.add returns true only the first time a state is seen, so each state is expanded once.
                if (reached.add(next)) {
                    pending.add(next);
                }
            }
        }
        return reached;
    }

    /**
     * Builds the one-column transition matrix over ptoi indices: entry [to][from] counts the column
     * patterns that move pattern set {@code from} to pattern set {@code to}.
     */
    private static long[][] buildTransitionMatrix() {
        int sz = ptoi.size();
        long[][] tbl = new long[sz][sz];
        for (int k : patternSets) {
            int id = ptoi.get(k);
            for (int to : buildTable(k)) {
                tbl[ptoi.get(to)][id]++;
            }
        }
        return tbl;
    }

    /**
     * Dense matrix and scalar arithmetic modulo {@code MOD}.
     *
     * <p>All routines assume their inputs are already reduced into {@code [0, MOD)}. The one exception is
     * {@code pow(long, long)}, which reduces its base first.
     */
    public static class Matrix {
        /**
         * Lazy-reduction threshold used by {@code mul}. It is a multiple of {@code MOD} small enough that
         * adding one more product (below {@code MOD * MOD}) to a value below it stays under
         * {@code Long.MAX_VALUE}: 9 * MOD^2 is about 8.97e18.
         */
        static final long MOD2 = (long)MOD * MOD * 8;

        /** Returns {@code a^x mod MOD} by binary exponentiation; returns 1 when {@code x <= 0}. */
        static long pow(long a, long x) {
            // Reduce the base first so that a * a below cannot overflow for unreduced or negative inputs.
            a %= MOD;
            if (a < 0) {
                a += MOD;
            }
            long res = 1;
            while (x > 0) {
                if ((x & 1) != 0) {
                    res = res * a % MOD;
                }
                a = a * a % MOD;
                x >>= 1;
            }
            return res;
        }

        /** Returns the n x n identity matrix. */
        static long[][] e(int n) {
            long[][] mat = new long[n][n];
            for (int i = 0; i < n; i++) {
                mat[i][i] = 1;
            }
            return mat;
        }

        /**
         * Returns the multiplicative inverse of {@code a} modulo {@code MOD}, computed as
         * {@code a^(MOD-2)} (Fermat's little theorem; MOD is prime).
         * Returns 0 when {@code a} is divisible by MOD, which has no inverse.
         */
        static long inv(long a) {
            return pow(a, MOD - 2);
        }

        /** Swaps rows p and q of x in place. */
        static void swapRow(long[][] x, int p, int q) {
            int n = x[0].length;
            for (int i = 0; i < n; i++) {
                long tmp = x[p][i];
                x[p][i] = x[q][i];
                x[q][i] = tmp;
            }
        }

        /** Adds {@code mul} times row q to row p in place, modulo MOD; {@code mul} may lie in (-MOD, MOD). */
        static void addRow(long[][] x, int p, int q, long mul) {
            int n = x[0].length;
            mul = mul < 0 ? MOD + mul : mul;
            for (int i = 0; i < n; i++) {
                long add = x[q][i] * mul % MOD;
                x[p][i] += add;
                if (x[p][i] >= MOD) {
                    x[p][i] -= MOD;
                }
            }
        }

        /** Multiplies row p of x by {@code mul} (in [0, MOD)) in place, modulo MOD. */
        static void mulRow(long[][] x, int p, long mul) {
            int n = x[0].length;
            for (int i = 0; i < n; i++) {
                x[p][i] = x[p][i] * mul % MOD;
            }
        }

        /**
         * Returns the inverse of the square matrix {@code x} modulo MOD, using Gauss-Jordan elimination.
         * {@code x} itself is not modified.
         *
         * @throws ArithmeticException if {@code x} is singular modulo MOD
         */
        static long[][] inv(long[][] x) {
            int n = x.length;

            long[][] fr = new long[n][n];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    fr[i][j] = x[i][j];
                }
            }
            long[][] to = e(n);
            for (int i = 0; i < n; i++) {
                int pos = i;
                while (pos < n && fr[pos][i] == 0) {
                    pos++;
                }
                // No non-zero pivot in this column: previously this fell through to swapRow(fr, n, i)
                // and crashed with an ArrayIndexOutOfBoundsException.
                if (pos == n) {
                    throw new ArithmeticException("matrix is singular modulo " + MOD);
                }
                if (pos != i) {
                    swapRow(fr, pos, i);
                    swapRow(to, pos, i);
                }

                long kake = inv(fr[i][i]);
                mulRow(fr, i, kake);
                mulRow(to, i, kake);

                for (int j = 0; j < n; j++) {
                    if (i != j) {
                        long bai = -fr[j][i];
                        addRow(to, j, i, bai);
                        addRow(fr, j, i, bai);
                    }
                }
            }
            return to;
        }

        /** Returns {@code x^p} modulo MOD by repeated squaring; returns the identity when {@code p <= 0}. */
        static long[][] pow(long[][] x, long p) {
            int n = x.length;
            long[][] ret = e(n);
            while (p >= 1) {
                if ((p & 1) == 1) {
                    ret = mul(ret, x);
                }
                x = mul(x, x);
                p >>>= 1;
            }
            return ret;
        }

        /**
         * Returns the product {@code a * b} modulo MOD for an (n x k) and a (k x m) matrix.
         * Sums are reduced lazily against MOD2 and then taken mod MOD once per entry.
         */
        static long[][] mul(long[][] a, long[][] b) {
            int n = a.length;
            int k = a[0].length;
            int m = b[0].length;

            long[][] ret = new long[n][m];
            long[] row = new long[m];
            for (int i = 0; i < n; i++) {
                Arrays.fill(row, 0);
                for (int l = 0; l < k; l++) {
                    for (int j = 0; j < m; j++) {
                        row[j] += a[i][l] * b[l][j];
                        if (row[j] >= MOD2) {
                            row[j] -= MOD2;
                        }
                    }
                }
                for (int j = 0; j < m; j++) {
                    ret[i][j] = row[j] % MOD;
                }
            }
            return ret;
        }
    }




    // Pattern sets reachable from the start set {0} (bitmask value 1), computed in main.
    // A pattern set is an int bitmask whose bit q marks protrusion pattern q as achievable.
    static Set<Integer> patternSets;

    // Number of cells per column; column patterns range over 0 .. (1 << H) - 1.
    static int H;
    // Number of columns, as read from input in main.
    static int W;

    // Index of each reachable pattern set in the transition matrix (assigned in main).
    static Map<Integer,Integer> ptoi;
    // Cache for buildTable: pattern set -> successor pattern set for each column pattern (array of length 1 << H).
    static Map<Integer, int[]> memo;

    /**
     * Computes, for every column pattern {@code p}, the set of protrusion patterns that can stick out into
     * the next column. Here {@code p} is a bitmask of the selected cells among this column's H cells. The
     * previous column may protrude into this one with any pattern in {@code lastPatternSet}.
     *
     * <p>Pattern sets are int bitmasks over the {@code 1 << H} protrusion patterns (bit q set means pattern
     * q is achievable). This representation therefore only works while {@code 1 << H <= 32}, i.e.
     * {@code H <= 5}. Results are cached in memo.
     *
     * @param lastPatternSet protrusion patterns achievable out of the previous column
     * @return array of length {@code 1 << H} mapping each column pattern to its successor pattern set
     */
    public static int[] buildTable(int lastPatternSet) {
        int[] cached = memo.get(lastPatternSet);
        if (cached != null) {
            return cached;
        }

        int patterns = 1 << H;
        // placePattern depends only on its argument and costs 3^H placements per call. The same needToFill
        // value recurs for many (p, last) pairs, so compute each value at most once here.
        int[] placed = new int[patterns];
        boolean[] placedKnown = new boolean[patterns];

        int[] nextSet = new int[patterns];
        for (int p = 0; p < patterns; p++) {
            int nextPatternSet = 0;
            for (int last = 0; last < patterns; last++) {
                // Usable only if this protrusion was achievable and lands entirely on selected cells of p.
                if ((lastPatternSet & (1 << last)) == 0 || (p & last) != last) {
                    continue;
                }
                int needToFill = p ^ last;
                if (!placedKnown[needToFill]) {
                    placed[needToFill] = placePattern(needToFill);
                    placedKnown[needToFill] = true;
                }
                nextPatternSet |= placed[needToFill];
            }
            nextSet[p] = nextPatternSet;
        }
        memo.put(lastPatternSet, nextSet);
        return nextSet;
    }

    /**
     * Collects every protrusion pattern into the next column that can result from covering exactly the
     * cells in {@code needToFill} of the current column. It tries every per-row action assignment (see
     * place and decode3).
     *
     * @param needToFill bitmask of current-column cells that must be covered, and no others
     * @return bitmask over next-column patterns; bit q is set when pattern q is achievable
     */
    public static int placePattern(int needToFill) {
        final int assignments = (int) Math.pow(3, H);
        int achievable = 0;
        for (int code = 0; code < assignments; code++) {
            int[][] cover = place(decode3(code));
            if (cover == null || toPattern(cover[0]) != needToFill) {
                continue;
            }
            achievable |= 1 << toPattern(cover[1]);
        }
        return achievable;
    }

    /**
     * Packs the first H entries of {@code p} into a bitmask: bit i is set exactly when {@code p[i] == 1}.
     */
    public static int toPattern(int[] p) {
        int mask = 0;
        int i = 0;
        while (i < H) {
            mask |= (p[i] == 1 ? 1 : 0) << i;
            i++;
        }
        return mask;
    }

    /**
     * Applies one action per row and returns the resulting occupancy of the current and next column.
     *
     * <p>Action codes: 0 = nothing; 1 = a piece covering row i of this column and row i of the next
     * column; 2 = a piece covering rows i and i+1 of this column.
     *
     * @param pt action code for each of the H rows
     * @return {@code {here, next}} occupancy counts per row, or {@code null} if a code-2 piece would run
     *     past the last row or two pieces overlap in the current column
     */
    public static int[][] place(int[] pt) {
        int[] here = new int[H];
        int[] next = new int[H];
        for (int row = 0; row < H; row++) {
            switch (pt[row]) {
                case 1:
                    here[row]++;
                    next[row]++;
                    break;
                case 2:
                    here[row]++;
                    if (row + 1 == H) {
                        return null;
                    }
                    here[row + 1]++;
                    break;
                default:
                    break;
            }
        }
        for (int count : here) {
            if (count >= 2) {
                return null;
            }
        }
        return new int[][] {here, next};
    }

    /** Splits {@code p3} into its H lowest base-3 digits, least significant digit first. */
    public static int[] decode3(int p3) {
        int[] digits = new int[H];
        int rest = p3;
        for (int idx = 0; idx < digits.length; idx++) {
            digits[idx] = rest % 3;
            rest = rest / 3;
        }
        return digits;
    }

    // Modulus for the answer and all Matrix arithmetic. Matrix.inv relies on it being prime (it computes
    // pow(a, MOD - 2)), and Matrix.MOD2 = MOD * MOD * 8 must fit in a long.
    static final int MOD = 998244353;

    /** Writes {@code label} with the milliseconds elapsed since __startTime to stderr via debug. */
    private static void printTime(String label) {
        long elapsed = System.currentTimeMillis() - __startTime;
        debug(label, elapsed);
    }

    /** Prints the arguments to stderr, rendering nested arrays via {@link Arrays#deepToString}. */
    private static void debug(Object... o) {
        String rendered = Arrays.deepToString(o);
        System.err.println(rendered);
    }

    /**
     * Buffered reader of tokens separated by space, tab, CR or LF, read from a byte stream.
     *
     * <p>Each byte maps to one char (values 0-255), so it is intended for ASCII input.
     * After next() has reported end of input once, any further read throws {@link InputMismatchException}.
     */
    public static class InputReader {
        private static final int BUFFER_LENGTH = 1 << 12;
        private final InputStream stream;
        private final byte[] buf = new byte[BUFFER_LENGTH];
        private int curChar;
        private int numChars;

        public InputReader(InputStream stream) {
            this.stream = stream;
        }

        /**
         * Returns the next byte as an unsigned value in [0, 255], or -1 at end of input.
         *
         * @throws InputMismatchException if end of input was already reported, or if reading fails
         *     (the {@link IOException} is attached as the cause)
         */
        private int next() {
            if (numChars == -1) {
                throw new InputMismatchException();
            }
            if (curChar >= numChars) {
                curChar = 0;
                try {
                    numChars = stream.read(buf);
                } catch (IOException e) {
                    // InputMismatchException has no (message, cause) constructor, so attach the cause manually.
                    InputMismatchException wrapped =
                            new InputMismatchException("failed to read input: " + e.getMessage());
                    wrapped.initCause(e);
                    throw wrapped;
                }
                if (numChars <= 0) {
                    return -1;
                }
            }
            // Mask to an unsigned value: a raw 0xFF byte would otherwise sign-extend to -1
            // and be mistaken for end of input.
            return buf[curChar++] & 0xFF;
        }

        /** Returns the next non-whitespace byte as a char. */
        public char nextChar() {
            return (char) skipWhileSpace();
        }

        /** Returns the next whitespace-delimited token. */
        public String nextToken() {
            int c = skipWhileSpace();
            StringBuilder res = new StringBuilder();
            do {
                res.append((char) c);
                c = next();
            } while (!isSpaceChar(c));
            return res.toString();
        }

        /** Reads the next integer token, narrowing the parsed long to an int. */
        public int nextInt() {
            return (int) nextLong();
        }

        /**
         * Parses the next token as an optionally negative decimal integer. Overflow is not detected.
         *
         * @throws InputMismatchException if the token contains a non-digit character
         */
        public long nextLong() {
            int c = skipWhileSpace();
            long sgn = 1;
            if (c == '-') {
                sgn = -1;
                c = next();
            }
            long res = 0;
            do {
                if (c < '0' || c > '9') {
                    throw new InputMismatchException();
                }
                res *= 10;
                res += c - '0';
                c = next();
            } while (!isSpaceChar(c));
            return res * sgn;
        }

        /** Parses the next token with {@link Double#parseDouble}. */
        public double nextDouble() {
            return Double.parseDouble(nextToken());
        }

        /** Consumes whitespace and returns the first non-whitespace byte. */
        int skipWhileSpace() {
            int c = next();
            while (isSpaceChar(c)) {
                c = next();
            }
            return c;
        }

        /** True for space, LF, CR, tab, and the end-of-input marker -1. */
        boolean isSpaceChar(int c) {
            return c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == -1;
        }
    }
}

Submission Info

Submission Time
Task H - タイル張り (Tiling)
User hamadu
Language Java8 (OpenJDK 1.8.0)
Score 1000
Code Size 10873 Byte
Status
Exec Time 349 ms
Memory 63964 KB

Test Cases

Set Name Score / Max Score Test Cases
Sample 0 / 0 s1.txt, s2.txt, s3.txt
All 1000 / 1000 01.txt, 02.txt, 03.txt, 04.txt, 05.txt, 06.txt, 07.txt, 08.txt, 09.txt, 10.txt, 11.txt, 12.txt, 13.txt, 14.txt, 15.txt, 16.txt, 17.txt, 18.txt, 19.txt, 20.txt, 21.txt, 22.txt, s1.txt, s2.txt, s3.txt
Case Name Status Exec Time Memory
01.txt 283 ms 62812 KB
02.txt 232 ms 60228 KB
03.txt 312 ms 62136 KB
04.txt 322 ms 63964 KB
05.txt 312 ms 59860 KB
06.txt 256 ms 61660 KB
07.txt 108 ms 23252 KB
08.txt 123 ms 24272 KB
09.txt 123 ms 26068 KB
10.txt 76 ms 21204 KB
11.txt 75 ms 19156 KB
12.txt 70 ms 21332 KB
13.txt 327 ms 60100 KB
14.txt 72 ms 23380 KB
15.txt 70 ms 18900 KB
16.txt 68 ms 18004 KB
17.txt 236 ms 56580 KB
18.txt 72 ms 19028 KB
19.txt 71 ms 19540 KB
20.txt 349 ms 62268 KB
21.txt 270 ms 62132 KB
22.txt 326 ms 61836 KB
s1.txt 71 ms 21460 KB
s2.txt 72 ms 19540 KB
s3.txt 281 ms 61496 KB