2020/04/30

AtCoder 用 Java テンプレート (剰余演算編)

I/O編に引き続いて、剰余演算編。
AtCoder では、素数を法とした剰余類環上で答えを出す問題がちょいちょい出るようです。剰余類環の性質を知っているかどうかを試したいというよりは、巨大数計算で困らせたくないという親切心のようなので、よく使う性質はコードテンプレートに入れておこうかと思い、作ってみました。

コードテンプレート

今回は、ABC156 のD問題を解いてみたコードと共に紹介します。class IO のところは今回は省略しますのでI/O編をご覧ください。

public final class MainABC156D {
  private static void execute(final IO io) throws Exception {
    final int n = io.nextInt();
    final int a = io.nextInt();
    final int b = io.nextInt();
    MC c = new MC(1000000007);
    long x1 = c.power(2, n) - 1;
    long x2 = c.combination(n, a);
    long x3 = c.combination(n, b);
    io.println(c.add(x1, -(x2 + x3)));
  }

  public static void main(final String[] args) throws Exception {
    try (IO io = new IO(System.in, System.out)) {
      execute(io);
    }
  }

  public static final class MC {
    private final int mod;
    public MC(final int mod) {
      this.mod = mod;
    }

    public long mod(long x) {
      x %= mod;
      if (x < 0) {
        x += mod;
      }
      return x;
    }

    public long add(final long a, final long b) {
      return mod(a + b);
    }

    public long mul(final long a, final long b) {
      return mod(a * b);
    }

    public long div(final long numerator, final long denominator) {
      return mod(numerator * inverse(denominator));
    }

    public long power(long base, long exp) {
      long ret = 1;
      base %= mod;
      while (exp > 0) {
        if ((exp & 1) == 1) {
          ret = mul(ret, base);
        }
        base = mul(base, base);
        exp >>= 1;
      }
      return ret;
    }

    public long inverse(final long x) {
      return power(x, mod - 2);
    }

    public long factorial(final int n) {
      return product(1, n);
    }

    public long product(final int start, final int end) {
      long result = 1;
      for (int i = start; i <= end; i++) {
        result *= i;
        result %= mod;
      }
      return result;
    }

    public long combination(final int n, int r) {
      if (r > n) {
        return 0;
      }
      return div(product(n - r + 1, n), factorial(r));
    }
  }

  public static final class IO implements AutoCloseable {
    // 省略
  }
}

解き方としては、n個から1個以上を選ぶ選び方の総和なので、$ \sum_{r=1}^{n} {}_n \mathrm{C} _r = 2^n - 1 $ がまず基本にあり、そこから $ {}_n \mathrm{C} _a $ と $ {}_n \mathrm{C} _b $ を引いてあげれば良いという発想です。
7〜10行目を見ていただくと、剰余のことを気にせず、解き方通りに書けていることが分かるかと思います。これが今回の狙いです。

解説

19行目から始まる MC という nested クラスに剰余類環上の計算をさせています。コンストラクタで法を渡すと、それに沿った計算をしてくれるインスタンスが出来ます。今回は $ 10^9+7 $ が法なので、そういうインスタンスを作っています (6行目)。
MC には、今までざっと解いてみた中で必要そうなメソッドを用意してみました。

mod()

まず、剰余を取る mod() です。普通に % 演算子でいいのですが、負数に対して % 演算子を適用すると、負数が返ってくるのが Java の仕様です。これでは剰余類環上に乗らないので、負数になったら法を加算してあげています。
負数にはならないと分かっているときは無駄になる計算量ですが、オーダーに影響するわけではないので、ミスらないよう、深く考えずに mod() を使えばいいかなと思っています。

add(), mul()

足したり掛けたりして mod() です。減算は add() にマイナスして渡せばいいので、作りません。

power()

べき乗演算です。$ a^n $ に対して、愚直に $n$ 回乗算をすると $O(n)$ ですが、バイナリ法で $O(\log{n})$ にします。$n$ が奇数なら $a \cdot (a^2)^{n/2}$、$n$ が偶数なら $(a^2)^{n/2}$ として、$a$ を $a^2$、$n$ を $n/2$ に置き換えながら繰り返します。

inverse()

除算に行く前に、逆元を考える必要があります。剰余類環上で考える逆数のようなもので、$a$ の逆元 $a^{-1}$ が計算できれば、$a/b$ にあたるものが $a \cdot b^{-1}$ として計算できます。
$a$ の逆元は法 $p$ が素数で、$a$ と $p$ が互いに素なら、$ a^{p-2} $ になります (フェルマーの小定理です)。したがって power() で計算できます。ちなみに $a$ が $p$ の倍数のときはだめです。0除算に相当します。

div()

inverse() ができたので、前述の通りです。

product()

$ {}_n \mathrm{C} _r = \frac{n!}{r!(n-r)!} $ を計算したいのですが、階乗をいっぱい計算するのは大変なので、$n!$ と $(n-r)!$ は相殺して、$ (\prod_{x=n-r+1}^{n} x) \cdot (r!)^{-1} $ にします。連続する整数の乗算をする関数として product() をまず作ります。

factorial()

階乗 (factorial) は、product() を起点1で計算するだけ。

combination()

product(), factorial(), div() があればもう簡単。

これでもう剰余計算が出てきても大丈夫!