2020/04/29

AtCoder 用 Java テンプレート (I/O編)

Twitter でなぜか競技プログラミングクラスタにつながったようなので、ちょっと AtCoder をやってみたり。競技プログラミングは初めてです。あ、大昔に SuperCon '95 や PSC 98 に出ましたが、TopCoder 系のは初めて。

ぶっつけ本番でやってみましたが、けっこう難しいし、Java で参戦したんですが、I/O 系のコードがめんどくさい。ちょっとテンプレートを用意してみることにしました。

テンプレートコード

public final class MainTemplate {
  private static void execute(final IO io) throws Exception {
  }

  public static void main(final String[] args) throws Exception {
    try (IO io = new IO(System.in, System.out)) {
      execute(io);
    }
  }

  public static final class IO implements AutoCloseable {
    private final java.io.InputStream in;
    private final java.io.BufferedOutputStream out;
    private static final int BUFFER_SIZE = 1 << 16;
    private final byte[] buf = new byte[BUFFER_SIZE];
    private int pos = 0;
    private int end = 0;
    public IO(final java.io.InputStream in, final java.io.OutputStream out) throws java.io.IOException {
      this.in = in;
      this.out = new java.io.BufferedOutputStream(out);
    }

    @Override
    public void close() throws java.io.IOException {
      out.close();
    }

    private int fetchByte() throws java.io.IOException {
      if (pos >= end) {
        pos = 0;
        end = in.read(buf);
        if (end <= 0) {
          return -1;
        }
      }
      return buf[pos++];
    }

    public String nextString() throws java.io.IOException {
      int c;
      for (c = fetchByte(); c <= ' '; c = fetchByte()) {
      }
      final StringBuilder sb = new StringBuilder();
      for (; c > ' '; c = fetchByte()) {
        sb.append((char) c);
      }
      return sb.toString();
    }

    public String[] nextString(final int n) throws java.io.IOException {
      final String[] result = new String[n];
      for (int i = 0; i < n; i++) {
        result[i] = nextString();
      }
      return result;
    }

    public int nextInt() throws java.io.IOException {
      int val = 0;
      int c;
      for (c = fetchByte(); c <= ' '; c = fetchByte()) {
      }
      boolean neg = c == '-';
      if (c == '-' || c == '+') {
        c = fetchByte();
      }
      for (; c >= '0' && c <= '9'; c = fetchByte()) {
        val = (val << 3) + (val << 1) + (c & 15);
      }
      return neg ? -val : val;
    }

    public int[] nextInt(final int n) throws java.io.IOException {
      final int[] result = new int[n];
      for (int i = 0; i < n; i++) {
        result[i] = nextInt();
      }
      return result;
    }

    public long nextLong() throws java.io.IOException {
      long val = 0;
      int c;
      for (c = fetchByte(); c <= ' '; c = fetchByte()) {
      }
      boolean neg = c == '-';
      if (c == '-' || c == '+') {
        c = fetchByte();
      }
      for (; c >= '0' && c <= '9'; c = fetchByte()) {
        val = (val << 3) + (val << 1) + (c & 15);
      }
      return neg ? -val : val;
    }

    public long[] nextLong(final int n) throws java.io.IOException {
      final long[] result = new long[n];
      for (int i = 0; i < n; i++) {
        result[i] = nextLong();
      }
      return result;
    }

    public void print(final Object a) throws java.io.IOException {
      out.write(a.toString().getBytes());
    }

    private static final byte[] SP = new byte[] { 0x20 };
    public void printsp(final Object a) throws java.io.IOException {
      out.write(a.toString().getBytes());
      out.write(SP);
    }

    private static final byte[] CRLF = System.lineSeparator().getBytes();
    public void println() throws java.io.IOException {
      out.write(CRLF);
    }

    public void println(final Object a) throws java.io.IOException {
      out.write(a.toString().getBytes());
      out.write(CRLF);
    }

    public void printaln(final int[] a) throws java.io.IOException {
      for (int i = 0, n = a.length; i < n; i++) {
        out.write(Integer.toString(a[i]).getBytes());
        out.write(CRLF);
      }
    }

    public void printasp(final int[] a) throws java.io.IOException {
      for (int i = 0, n = a.length; i < n; i++) {
        out.write(Integer.toString(a[i]).getBytes());
        out.write(SP);
      }
    }

    public void printaln(final long[] a) throws java.io.IOException {
      for (int i = 0, n = a.length; i < n; i++) {
        out.write(Long.toString(a[i]).getBytes());
        out.write(CRLF);
      }
    }

    public void printasp(final long[] a) throws java.io.IOException {
      for (int i = 0, n = a.length; i < n; i++) {
        out.write(Long.toString(a[i]).getBytes());
        out.write(SP);
      }
    }

    public void printaln(final Object[] a) throws java.io.IOException {
      for (int i = 0, n = a.length; i < n; i++) {
        out.write(a[i].toString().getBytes());
        out.write(CRLF);
      }
    }

    public void printasp(final Object[] a) throws java.io.IOException {
      for (int i = 0, n = a.length; i < n; i++) {
        out.write(a[i].toString().getBytes());
        out.write(SP);
      }
    }

    public void flush() throws java.io.IOException {
      out.flush();
    }
  }
}

説明

IO という nested static クラスを作り、これを介して標準入出力、つまり問題のパラメータの読み込みと、結果の出力をやります。
main() の中で IO インスタンスを用意して、それを execute() に渡すようにしています。終わったら IO を close() しますが、ここで確実に出力を flush() するようにしています。IO は AutoCloseable にして、try-with-resources 構文が使えるようにしてあります。
入力は、BufferedReader にせず、InputStream のまま取り扱い、自前のバッファを使って buffered read を実現しています (28〜37行目)。

ほとんど整数問題なので、整数の読み取りを高速に行えるよう、nextInt() を用意しています。1文字ずつ読み出し、数字が続く限り10倍して数値化したものを足す、という繰り返しなのですが、ちょっと工夫があり、10倍は8倍(3ビットシフト)と2倍(1ビットシフト)の加算としています。数値化は文字コードに対して 15 (2進数で1111) とのビット論理和を取ることで実行しています。'0' の文字コードは2進数で110000、'1' は 110001、...、'9' は 111001 なので、この方法がうまくいきます (68行目など)。

出力の方は、BufferedOutputStream を使っています。System.out は頻繁に flush() が発生してちょっと効率が悪いので、手前でバッファするよう BufferedOutputStream をかぶせています。non-ASCII 文字の出力は無さそうなので、BufferedWriter にはしていません。