はじめに
TFUGベイズ分科会にてPPLについて話しました。改めてPPLを複数比較してみたことで、一層NumPyrpの書きやすさにほれぼれとしました。
現状、PPLをまとめると
- 通常利用:Stan
- より低レベルを触れる研究用途:TensorFlow Probability
- 深層生成モデル及びベイズニューラルネットの変分推論 : Pyro
- 上記及び高速なMCMC : NumPyro
という具合です。実際、速度やインターフェースの書きやすさを見るとNumPyroが個人的には抜けているのですが、一方でバックエンドがJaxであるという点が、まだまだ広がりを見せない要因にもなっているかと思われます(もしも、PyTorchよりもJaxの方がメジャーだったとすれば、Pyroを選ぶ理由がほとんど無い程。と言っても、Jaxのfunctionalな書き方はかなり慣れを要するし、これまでのPythonの使い方をフル活用したDLライブラリとはひと味違う)。
そんなこんなで、おそらく敬遠されているであろうNumPyroがこんなにも簡単に使えるのだということを見ていただこうと記事にしました。今回の題材は、TFUGにてTensorFlow Probabilityの基礎と応用の発表を多いに参考にしております。
NumPyro基本
ライブラリの準備
colabを利用する場合は冒頭に下記のコードによってライブラリを入れる必要があります。 numpyroのインストール時に必要なJaxも揃います。funsorは現在Pyroの低階層に入ることが想定されているTensorラッピングモジュールで、NumPyroの場合はfunsorがあることで、離散確率変数の取り扱いをサポートしてもらえます。
コード上では直接funsorを使うことはありません。
!pip install numpyro !pip install funsor
ライブラリのインポートは下記です。
import jax.numpy as jnp from jax import random import numpyro import numpyro.distributions as dist import numpyro.infer as infer from numpyro import sample, plate, handlers import funsor import matplotlib.pyplot as plt import seaborn as sns plt.style.use("seaborn")
確率分布
まずは確率分布の取り扱いを見ましょう。基本的には torch.distributions
モジュールと同じ使い方であり、torch.distributions
モジュールは TensorFlow Probabilityの設計をかなり参考にしているため、実質この手のPPLを利用する人には全く違和感のないものになっています。
normal = dist.Normal(loc=0., scale=1.)
上記のコードによって、平均 $0$ 標準偏差 $1$ の正規分布のインスタンスが作成できます。次に、このインスタンスのメソッドを軽く見ていきましょう。
samples = normal.sample(random.PRNGKey(0), (1000,)) sns.distplot(samples)
上記のコードによって正規分布から1000個の値をドローできます。結果である samples
は shape = (1000, )
の形をしており、各値が独立に生成されることになります。
また、お決まりのように log_prob
メソッドが準備されており
normal.log_prob(0.)
といった具合で、正規分布から見た $0$ の生成確率の対数が計算できます。ここに先程の samples
を入れてやると、1000個それぞれの対数確率が格納された jax.numpy.array
が返ってきます。ここまでを駆使すれば、問題なく最尤推定などが実行できるはずです。
transoforms
モジュール (tfp.bijector
相当)
確率分布の変換を行うためのモジュール numpyro.distributions.transforms
も重要です。
下記のように dist.TransformedDistributions
に対して base_distribution
として正規分布を、transforms
として dist.transforms.ExpTransform
のインスタンスをそれぞれ渡してみましょう。
log_normal = dist.TransformedDistribution( base_distribution=dist.Normal(0., 0.5), transforms=dist.transforms.ExpTransform() ) samples = log_normal.sample(random.PRNGKey(0), (1000,)) sns.distplot(samples)
これによって、正規分布から生起した確率変数を exp
によって変換して得られる $Y = \exp(X)$ を確率変数の実現値として生成するような分布を作ることができます。実際のところ、これにはすでに対数正規分布という固有名詞が与えられており、わざわざ作る必要はなく dist.LogNormal
によって整備されています。
この transforms
モジュールは複雑な変換を連鎖させることも可能で、固有名詞のついていない独自の分布を構成することが可能です。
transforms_list = [ dist.transforms.ExpTransform(), dist.transforms.AffineTransform(loc=-1, scale=2.0), dist.transforms.PowerTransform(exponent=0.8), dist.transforms.SigmoidTransform() ] flow_dist = dist.TransformedDistribution( base_distribution=dist.Normal(0., 0.5), transforms=transforms_list ) samples = flow_dist.sample(random.PRNGKey(0), (1000,)) sns.distplot(samples)
これによって、乱数生成の元が単純な確率分布だったとしても、複雑な分布を表現できます。これがいわゆるVAEのFlowモデルの肝になっているわけですね。
これらを用いれば、TF Probabilityチックなことを自前でやることができるわけですが、NumPyroにはTFPのように確率分布クラスを複数組み合わせて同時分布を作り、その同時分布の対数確率を評価することで推論を実行する…という手続きは隠蔽されています。
実際にはこれらの分布から値をサンプリングする流れをpython関数で記述することで、同時分布を表現し、その流れをトレースできるクラスによって様々な操作が可能となっています。
変化点検知
変化点検知を例に、モデルの作成方法を見ます。
データ
データはベタ書きしてしまいます。これは各年の災害の発生件数です。
disaster_data = jnp.array([ 4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6, 3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5, 2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0, 1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2, 3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1]) years = jnp.arange(1851, 1962) plt.figure(figsize=(8, 4)) plt.plot(years, disaster_data, "o", color='b', markersize=6, alpha=0.5) plt.ylabel("Disaster count") plt.xlabel("Year")
モデル
これに対して下記のようなモデルを書きます。
def model(disaster_data=None, N=len(years)): e = sample("e", dist.Exponential(rate=1.0)) l = sample("l", dist.Exponential(rate=1.0)) s = sample("s", dist.Uniform(low=0, high=len(years))) with plate("data_plate", N): rate = numpyro.deterministic( "rate", jnp.where(jnp.arange(len(years)) < s, e, l) ) d_t = sample( "d_t", dist.Poisson(rate=rate), obs=disaster_data )
ご覧の通り、分布から生成された変数を使って更に別の分布から値を生成するという流れを具体的に書くことでモデリングが実施できます。with plate
コンテキストは非常に強力で、これの外にあるグローバルな確率変数を使って、コンテキスト内の確率変数は、条件付き独立で値が繰り返し生成されます。同等の動きのスニペットとして
rate_list = [] d_t_list = [] for n in range(N): rate = numpyro.deterministic( "rate", jnp.where(jnp.arange(len(years)) < s, e, l) ) rate_list.append(rate) d_t = sample( "d_t", dist.Poisson(rate=rate), obs=disaster_data[n] ) d_t_list.append(d_t) rate = jnp.stack(rate_list) d_t_list = jnp.stack(d_t_list)
という表記が考えられます。こちらはpythonの処理そのままで直感的ですが、この処理をvectorizeしてしまうのが plate
の正体です。どちらが書きやすいかは好みですが、グラフィカルモデル表記の対応も加味すれば、圧倒的に plate
の方が見やすいはずです。
また、NumPyroでは観測変数に相当する確率変数には sample(name, dist, obs=obs_data)
という書き方をします。通常 obs=None
がデフォルト引数で設定されており、obs=None
の場合は確率変数を分布から生起させ、obs
が与えられていれば、それを生成した体でデータ生成を進めてくれるということになります。そういうわけで、作ったモデルから本当にランダムに値を生成したければ obs
を書いた引数に対して None
を与えてやれば良いのです(そういうわけでモデル model
の引数がそうなっているのだ)。
事前分布からのサンプリングでモデルの動作確認
python関数で書かれたモデルは返り値をもたせても持たせなくても構いません。
handlers.seed
モジュールによって乱数シードを設定してあげて、乱数シードが設定されたモデルを handlers.trace
関数で囲えば、設定したシードによって確率変数のサンプリングを実施したインスタンスを返してくれます。
traced_model = handlers.trace(handlers.seed(model, 2020)) prior_sample = traced_model.get_trace()["d_t"]["value"] plt.plot(years, prior_sample, "o")
traced_model
は、乱数シード設定→サンプリングの流れを終えたNumPyro/Pyro独自のクラスインスタンスで、内部変数に model
内部で定義した sample()
関数によって呼び出された変数すべてを辞書型で所持しています。このあたりは、若干、そういう実装、そういう使い方に慣れる必要があり、生のTensorをひたすら扱うTFPよりも一見ややこしいのですが、逆にそれらを自身で管理する必要が無く、モジュール側に色々と任せることができて非常に楽です。
MCMC推論
推論はたった3行で実施できます。infer.NUTS
を利用して推論を実施します。この際に adapt_step_size
を True
にしておくと、ハイパーパラメータである更新幅をバーンイン期間(サンプルをお試しで発生させ、安定するまで捨てる期間)に自動でいい感じに調整してくれます。
mcmc.run()
の引数は第一引数に乱数シードを作るrandom.PRNGKey
のインスタンスを与え、残りの引数は model
で設定した引数を与えることになります。
kernel = infer.NUTS(model, adapt_step_size=True) mcmc = infer.MCMC(kernel, 3000, 10000, num_chains=3) mcmc.run(random.PRNGKey(0), disaster_data, len(years))
結果確認
結果の確認は基本的な統計量を
mcmc.print_summary() mean std median 5.0% 95.0% n_eff r_hat e 3.07 0.28 3.06 2.60 3.52 1941.29 1.00 l 0.94 0.12 0.93 0.75 1.14 1371.96 1.00 s 39.43 2.49 39.64 35.00 42.45 1514.09 1.00 Number of divergences: 0
と確認できます。
mcmc
インスタンスはArviZと呼ばれるベイズ推論の結果可視化ライブラリのデータ構造に合わせたものとなっていますので、可視化には積極的にArviZを使えばよいのですが、今回はネイティブのNumPyroを扱うために直接結果を取り出してみます。
とは言っても get_samples()
メソッドで結果を一括で取り出すことができ、これは辞書型となっております。key
は model
の中で sample
関数によって設定した名前がそのまま使われており、
mcmc_samples = mcmc.get_samples() e = mcmc_samples["e"] l = mcmc_samples["l"] s = mcmc_samples["s"] r = mcmc_samples["rate"]
と各々のデータを取り出せます。
更に、これらの推論結果を用いて、パラメータを事後分布による周辺化消去を実施し、予測分布を構築することも簡単に行えます。
predictive = infer.Predictive(model, mcmc_samples) predict_samples = predictive(random.PRNGKey(0), disaster_data=None, N=len(years)) r_predict_samples = predict_samples["rate"] d_predict_samples = predict_samples["d_t"]
としてやることで、パラメータ以外の確率変数を予測出力させることができます。ここで predictive
インスタンスの引数はお決まりの乱数シード生成器と、model
で設定した引数を続けるのですが、disaster_data=None
としてやることで、観測がありませんということ(すなわちモデル自身に出力させる)事が可能となります。
さて、結果を可視化してみましょう。95%予測区間と、ポアソン分布の平均パラメータの95%信用区間を一緒に表示してみましょう。
d_mean = d_predict_samples.mean(axis=0) r_mean = r_predict_samples.mean(axis=0) d_low, d_high = jnp.percentile(d_predict_samples, [2.5, 97.5], axis=0) r_low, r_high = jnp.percentile(r_predict_samples, [2.5, 97.5], axis=0) plt.plot(years, disaster_data, "o") plt.plot(years, d_mean, "r") plt.plot(years, r_mean, "g") plt.fill_between(years, d_low, d_high, "r", alpha=0.3) plt.fill_between(years, r_low, r_high, "g", alpha=0.3)
こうして変化が起きた年の推論が行えました。