HELLO CYBERNETICS

深層学習、機械学習、強化学習、信号処理、制御工学、量子計算などをテーマに扱っていきます

NumPyroの基本を変化点検知で見る

 

 

follow us in feedly

はじめに

TFUGベイズ分科会にてPPLについて話しました。改めてPPLを複数比較してみたことで、一層NumPyrpの書きやすさにほれぼれとしました。

www.hellocybernetics.tech

現状、PPLをまとめると

  • 通常利用:Stan
  • より低レベルを触れる研究用途:TensorFlow Probability
  • 深層生成モデル及びベイズニューラルネットの変分推論 : Pyro
  • 上記及び高速なMCMC : NumPyro

という具合です。実際、速度やインターフェースの書きやすさを見るとNumPyroが個人的には抜けているのですが、一方でバックエンドがJaxであるという点が、まだまだ広がりを見せない要因にもなっているかと思われます(もしも、PyTorchよりもJaxの方がメジャーだったとすれば、Pyroを選ぶ理由がほとんど無い程。と言っても、Jaxのfunctionalな書き方はかなり慣れを要するし、これまでのPythonの使い方をフル活用したDLライブラリとはひと味違う)。

そんなこんなで、おそらく敬遠されているであろうNumPyroがこんなにも簡単に使えるのだということを見ていただこうと記事にしました。今回の題材は、TFUGにてTensorFlow Probabilityの基礎と応用の発表を多いに参考にしております。

NumPyro基本

ライブラリの準備

colabを利用する場合は冒頭に下記のコードによってライブラリを入れる必要があります。 numpyroのインストール時に必要なJaxも揃います。funsorは現在Pyroの低階層に入ることが想定されているTensorラッピングモジュールで、NumPyroの場合はfunsorがあることで、離散確率変数の取り扱いをサポートしてもらえます。

コード上では直接funsorを使うことはありません。

!pip install numpyro
!pip install funsor

ライブラリのインポートは下記です。

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
import numpyro.infer as infer
from numpyro import sample, plate, handlers
import funsor

import matplotlib.pyplot as plt
import seaborn as sns

plt.style.use("seaborn")

確率分布

まずは確率分布の取り扱いを見ましょう。基本的には torch.distributions モジュールと同じ使い方であり、torch.distributions モジュールは TensorFlow Probabilityの設計をかなり参考にしているため、実質この手のPPLを利用する人には全く違和感のないものになっています。

normal = dist.Normal(loc=0., scale=1.)

上記のコードによって、平均 $0$ 標準偏差 $1$ の正規分布のインスタンスが作成できます。次に、このインスタンスのメソッドを軽く見ていきましょう。

samples = normal.sample(random.PRNGKey(0), (1000,))
sns.distplot(samples)

上記のコードによって正規分布から1000個の値をドローできます。結果である samplesshape = (1000, ) の形をしており、各値が独立に生成されることになります。

f:id:s0sem0y:20200909004557p:plain

また、お決まりのように log_prob メソッドが準備されており

normal.log_prob(0.)

といった具合で、正規分布から見た $0$ の生成確率の対数が計算できます。ここに先程の samples を入れてやると、1000個それぞれの対数確率が格納された jax.numpy.array が返ってきます。ここまでを駆使すれば、問題なく最尤推定などが実行できるはずです。

transoforms モジュール (tfp.bijector相当)

確率分布の変換を行うためのモジュール numpyro.distributions.transforms も重要です。 下記のように dist.TransformedDistributions に対して base_distribution として正規分布を、transforms として dist.transforms.ExpTransform のインスタンスをそれぞれ渡してみましょう。

log_normal = dist.TransformedDistribution(
    base_distribution=dist.Normal(0., 0.5),
    transforms=dist.transforms.ExpTransform()
)

samples = log_normal.sample(random.PRNGKey(0), (1000,))
sns.distplot(samples)

f:id:s0sem0y:20200909005129p:plain

これによって、正規分布から生起した確率変数を exp によって変換して得られる $Y = \exp(X)$ を確率変数の実現値として生成するような分布を作ることができます。実際のところ、これにはすでに対数正規分布という固有名詞が与えられており、わざわざ作る必要はなく dist.LogNormal によって整備されています。

この transforms モジュールは複雑な変換を連鎖させることも可能で、固有名詞のついていない独自の分布を構成することが可能です。

transforms_list = [
    dist.transforms.ExpTransform(),
    dist.transforms.AffineTransform(loc=-1, scale=2.0),
    dist.transforms.PowerTransform(exponent=0.8),
    dist.transforms.SigmoidTransform()
]

flow_dist = dist.TransformedDistribution(
    base_distribution=dist.Normal(0., 0.5),
    transforms=transforms_list
)

samples = flow_dist.sample(random.PRNGKey(0), (1000,))
sns.distplot(samples)

f:id:s0sem0y:20200909005458p:plain

これによって、乱数生成の元が単純な確率分布だったとしても、複雑な分布を表現できます。これがいわゆるVAEのFlowモデルの肝になっているわけですね。

これらを用いれば、TF Probabilityチックなことを自前でやることができるわけですが、NumPyroにはTFPのように確率分布クラスを複数組み合わせて同時分布を作り、その同時分布の対数確率を評価することで推論を実行する…という手続きは隠蔽されています。

実際にはこれらの分布から値をサンプリングする流れをpython関数で記述することで、同時分布を表現し、その流れをトレースできるクラスによって様々な操作が可能となっています。

変化点検知

変化点検知を例に、モデルの作成方法を見ます。

データ

データはベタ書きしてしまいます。これは各年の災害の発生件数です。

disaster_data = jnp.array([ 4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
                            3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
                            2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0,
                            1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
                            0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
                            3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
                            0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1])
years = jnp.arange(1851, 1962)

plt.figure(figsize=(8, 4))
plt.plot(years, disaster_data, "o", color='b', markersize=6, alpha=0.5)
plt.ylabel("Disaster count")
plt.xlabel("Year")

f:id:s0sem0y:20200909010220p:plain

モデル

これに対して下記のようなモデルを書きます。

def model(disaster_data=None, N=len(years)):
    e = sample("e", dist.Exponential(rate=1.0))
    l = sample("l", dist.Exponential(rate=1.0))
    s = sample("s", dist.Uniform(low=0, high=len(years)))
    with plate("data_plate", N):
        rate = numpyro.deterministic(
            "rate", jnp.where(jnp.arange(len(years)) < s, e, l)
        )
        d_t = sample(
            "d_t",
            dist.Poisson(rate=rate),
            obs=disaster_data
        )

ご覧の通り、分布から生成された変数を使って更に別の分布から値を生成するという流れを具体的に書くことでモデリングが実施できます。with plate コンテキストは非常に強力で、これの外にあるグローバルな確率変数を使って、コンテキスト内の確率変数は、条件付き独立で値が繰り返し生成されます。同等の動きのスニペットとして

    rate_list = []
    d_t_list = []
    for n in range(N):
        rate = numpyro.deterministic(
            "rate", jnp.where(jnp.arange(len(years)) < s, e, l)
        )
        rate_list.append(rate)
        d_t = sample(
            "d_t",
            dist.Poisson(rate=rate),
            obs=disaster_data[n]
        )
       d_t_list.append(d_t)
    rate = jnp.stack(rate_list)
    d_t_list = jnp.stack(d_t_list)

という表記が考えられます。こちらはpythonの処理そのままで直感的ですが、この処理をvectorizeしてしまうのが plate の正体です。どちらが書きやすいかは好みですが、グラフィカルモデル表記の対応も加味すれば、圧倒的に plate の方が見やすいはずです。

また、NumPyroでは観測変数に相当する確率変数には sample(name, dist, obs=obs_data) という書き方をします。通常 obs=Noneがデフォルト引数で設定されており、obs=None の場合は確率変数を分布から生起させ、obs が与えられていれば、それを生成した体でデータ生成を進めてくれるということになります。そういうわけで、作ったモデルから本当にランダムに値を生成したければ obs を書いた引数に対して None を与えてやれば良いのです(そういうわけでモデル model の引数がそうなっているのだ)。

事前分布からのサンプリングでモデルの動作確認

python関数で書かれたモデルは返り値をもたせても持たせなくても構いません。 handlers.seed モジュールによって乱数シードを設定してあげて、乱数シードが設定されたモデルを handlers.trace 関数で囲えば、設定したシードによって確率変数のサンプリングを実施したインスタンスを返してくれます。

traced_model = handlers.trace(handlers.seed(model, 2020))
prior_sample = traced_model.get_trace()["d_t"]["value"]

plt.plot(years, prior_sample, "o")

f:id:s0sem0y:20200909011353p:plain

traced_model は、乱数シード設定→サンプリングの流れを終えたNumPyro/Pyro独自のクラスインスタンスで、内部変数に model 内部で定義した sample() 関数によって呼び出された変数すべてを辞書型で所持しています。このあたりは、若干、そういう実装、そういう使い方に慣れる必要があり、生のTensorをひたすら扱うTFPよりも一見ややこしいのですが、逆にそれらを自身で管理する必要が無く、モジュール側に色々と任せることができて非常に楽です。

MCMC推論

推論はたった3行で実施できます。infer.NUTS を利用して推論を実施します。この際に adapt_step_sizeTrue にしておくと、ハイパーパラメータである更新幅をバーンイン期間(サンプルをお試しで発生させ、安定するまで捨てる期間)に自動でいい感じに調整してくれます。

mcmc.run()の引数は第一引数に乱数シードを作るrandom.PRNGKeyのインスタンスを与え、残りの引数は model で設定した引数を与えることになります。

kernel = infer.NUTS(model, adapt_step_size=True)
mcmc = infer.MCMC(kernel, 3000, 10000, num_chains=3)
mcmc.run(random.PRNGKey(0), disaster_data, len(years))

結果確認

結果の確認は基本的な統計量を

mcmc.print_summary()

                mean       std    median      5.0%     95.0%     n_eff     r_hat
         e      3.07      0.28      3.06      2.60      3.52   1941.29      1.00
         l      0.94      0.12      0.93      0.75      1.14   1371.96      1.00
         s     39.43      2.49     39.64     35.00     42.45   1514.09      1.00

Number of divergences: 0

と確認できます。

mcmc インスタンスはArviZと呼ばれるベイズ推論の結果可視化ライブラリのデータ構造に合わせたものとなっていますので、可視化には積極的にArviZを使えばよいのですが、今回はネイティブのNumPyroを扱うために直接結果を取り出してみます。

とは言っても get_samples() メソッドで結果を一括で取り出すことができ、これは辞書型となっております。keymodel の中で sample 関数によって設定した名前がそのまま使われており、

mcmc_samples = mcmc.get_samples()
e = mcmc_samples["e"]
l = mcmc_samples["l"]
s = mcmc_samples["s"]
r = mcmc_samples["rate"]

と各々のデータを取り出せます。

更に、これらの推論結果を用いて、パラメータを事後分布による周辺化消去を実施し、予測分布を構築することも簡単に行えます。

predictive = infer.Predictive(model, mcmc_samples)
predict_samples = predictive(random.PRNGKey(0), disaster_data=None, N=len(years))
r_predict_samples = predict_samples["rate"]
d_predict_samples = predict_samples["d_t"]

としてやることで、パラメータ以外の確率変数を予測出力させることができます。ここで predictive インスタンスの引数はお決まりの乱数シード生成器と、model で設定した引数を続けるのですが、disaster_data=None としてやることで、観測がありませんということ(すなわちモデル自身に出力させる)事が可能となります。

さて、結果を可視化してみましょう。95%予測区間と、ポアソン分布の平均パラメータの95%信用区間を一緒に表示してみましょう。

d_mean = d_predict_samples.mean(axis=0)
r_mean = r_predict_samples.mean(axis=0)
d_low, d_high = jnp.percentile(d_predict_samples, [2.5, 97.5], axis=0)
r_low, r_high = jnp.percentile(r_predict_samples, [2.5, 97.5], axis=0)

plt.plot(years, disaster_data, "o")
plt.plot(years, d_mean, "r")
plt.plot(years, r_mean, "g")
plt.fill_between(years, d_low, d_high, "r", alpha=0.3)
plt.fill_between(years, r_low, r_high, "g", alpha=0.3)

f:id:s0sem0y:20200909012822p:plain

こうして変化が起きた年の推論が行えました。