HELLO CYBERNETICS

深層学習、機械学習、強化学習、信号処理、制御工学、量子計算などをテーマに扱っていきます

Pyro on PyTorchでベイズ予測分布(MAP推定、変分推論、MCMC)

 

 

follow us in feedly

はじめに

Pyroが思ってた以上に便利になっており、正式リリース後のPyro1.X系は謎のメソッドを連発しなくてもデータ解析ができるようになっていました。 それに伴い、既に推奨されていない書き方などもあるため、今回は正式リリース後の書き方でベイズ予測分布を出すコードを見ていきます。

基本事項として

www.hellocybernetics.tech

www.hellocybernetics.tech

などの内容が理解できていると良いです。

前提として下記のインポート文を実施していることとします。

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
import pyro.infer as infer

plt.style.use("seaborn")

Pyroおさらい

データ

まずはおさらいとしてMAP推定を見ていきます。 PyroのインターフェースではMAP推定はデルタ変分分布を用いた変分推論、あるいは事後分布のラプラス近似によって実施されるため、MAP推定のやり方がわかっていれば、変分推論の書き方も自ずとわかります。

N = 20

x_data = np.random.rand(N).reshape(-1, 1)
y_data = 1 + 4*x_data + 0.1*np.random.randn(N).reshape([-1, 1])

x_data_th = torch.as_tensor(x_data) # should be torch.Tensor for pyro  
y_data_th = torch.as_tensor(y_data) 

plt.plot(x_data, y_data, "o")

f:id:s0sem0y:20191208174403p:plain

モデル

$$ \begin{align} w _ 0 &\sim {\rm Normal} (0, 10)\\ w _ 1 &\sim {\rm Normal} (0, 10) \end{align} $$

という事前分布と回帰モデル

$$ y \sim {\rm Normal}(y\mid w _ 0 + w _ 1 x, 0.5) $$

によって、同時分布を

$$ p(y , w _ 0, w _ 1 \mid x) = {\rm Normal} (w _ 0\mid 0, 10) {\rm Normal} (w _ 1 \mid 0, 10) {\rm Normal}(y\mid w _ 0 + w _ 1 x, 0.5) $$

と設計します。そのコードは下記のようになります。

def model(x_data, y_data):
    w0 = pyro.sample("w0", dist.Normal(0, 10))
    w1 = pyro.sample("w1", dist.Normal(0, 10))
    # plate の第二引数はサンプルの数。 x_data のデータ数に合わせている。
    with pyro.plate("plate", x_data.size(0)):
        y_ = w0 + w1*x_data
        y = pyro.sample("y", dist.Normal(y_, 0.5), obs=y_data)
    return y

ここで、事前分布からサンプリングされるパラメータを用いた生成モデルで、データ点を再現してみましょう。

x_index = torch.linspace(0, 1, 100)
with torch.no_grad():
    y_prior = model(x_index, None)

plt.plot(x_data, y_data, "ro")
plt.plot(x_index.numpy(), y_prior.numpy(), "o")

f:id:s0sem0y:20191208174529p:plain

当然、実際のデータを再現するような生成モデルにはなっていません。

MAP推定

MAP推定は変分推論のインターフェースを流用し、変分近似分布としてデルタ分布を選べば実行できます。 最適化手法としては適当にAdamを選びますが、ほか何でも良いです。

guide = infer.autoguide.guides.AutoDelta(model)
adam_params = {"lr": 1e-1, "betas": (0.95, 0.999)}
oprimizer = pyro.optim.Adam(adam_params)

svi = infer.SVI(model, guide, oprimizer, loss=infer.JitTrace_ELBO())

変分推論のインターフェースが完成したら、あとは推論を回すだけです。svistep() メソッドがあり、先ほど作成したモデルの引数を渡してやれば実行できます。 モデルの中で使われるのはPyTorchのTensorであることに注意しましょう。NumPyのまま渡すとエラーが出ます。

n_steps = 5000
for step in range(n_steps):
    svi.step(x_data_th, y_data_th)

推論完了後はグローバルに配置されている変分パラメータを下記のコードで見ることができます。

for name in pyro.get_param_store():
    print("{}: {}".format(name, pyro.param(name)))
# AutoDelta.w0: 0.9983226656913757
# AutoDelta.w1: 3.928438186645508

ここで名前をしっかり把握することが大事です。自分で変分分布を書き下す場合には良いですが、AutoGuideクラスで楽をした場合は、内部でどのような名前になっているのかをこのようにして把握する必要があります(もちろん命名規則はある。AutoDelta を利用したならば、model() 関数内で hoge と名付けられたサンプルに対しては、 AutoDelta.hoge という変分パラメータが当てられている)。

デルタ分布の推論後の変分パラメータはMAP推定解そのものであるため、これを取り出して条件付き分布を作成する condition 関数を使って予測モデルを作ります。

w0map = pyro.param("AutoDelta.w0")
w1map = pyro.param("AutoDelta.w1")
y_given_w = pyro.condition(model, data={"w0": w0map, "w1": w1map})

たったこれだけで、予測モデルが得られました。早速予測モデルからデータをサンプリングしてみます。

with torch.no_grad():
    y_predict_fixed_w = y_given_w(x_index, None)

plt.plot(x_data, y_data, "ro")
plt.plot(x_index.numpy(), y_predict_fixed_w.numpy(), "o")
plt.legend(["obs_data", "generated_data"])

f:id:s0sem0y:20191208175408p:plain

モデルがデータを再現するような形になっているように見えます。しかし本当のデータに比べばらつきが大きいです。 これは、自分が立てた統計モデル

$$ y \sim {\rm Normal}(y\mid w _ 0 + w _ 1 x, 0.5) $$

において、標準偏差の見積もりが大きいからです(実際のデータは 0.1 である)。 これを解決するには、標準偏差の見積もりをやり直す、あるいは統計モデルの標準偏差にも事前分布を入れて推論してしまう、ということが必要です。 上記の手順を真似るだけですので、標準偏差も確率変数として扱ってモデルを立ててみるのは初歩的な練習になるでしょう。

もともとMAP推定は点推定であり、回帰直線を直接得ることができるはずです。そのような場合には、ばらつきを乗せる前の値を model() の戻り値にすればいいでしょう。 ただし多くの場合では、MAP推定は単なる状況の把握に使われるだけなので今回は割愛します(そもそもMAP推定が目的なら、あまりPyroを使う意義はない)。

変分推論

さて、次は変分推論を見ていきます。データとモデルは上記のものを流用するので、変分近似分布を作るところから始めます。

変分近似分布

これも既に準備されているものを使うだけなら数行で実施できます。

pyro.clear_param_store()
guide = infer.autoguide.guides.AutoDiagonalNormal(model)

ここで pyro.clear_param_store() はグローバルに配置された変分パラメータの辞書を空にする関数です。 Notebookなどで、いろいろなモデルを1つのプロセスで検討していると、他のモデルの変分パラメータが混じってしまっている状態になります。実際にはすべてのモデルで変分パラメータの名前がかぶらないようにしていれば問題はありませんが、念のため、いらないパラメータは消しておくようにします。

ここで構築された変分近似分布は分散共分散行列が対角行列になっているような多次元正規分布です。 今回の場合はパラメータ w _ 0w _ 1 に対して $\mathbf w = (w _ 0, w _ 1) ^ {\mathbf T}$ と二次元の確率変数ベクトルだとして

$$ q(w _ 0, w _ 1) = {\rm Normal} (\mathbf m , \mathbf \Sigma) $$

と設計したことになります。 ここで $\mathbf m = (m _ {w _ 0}, m _ {w _ 1} ) ^ {\mathbf T} $ と $\mathbf \Sigma = {\rm diag} (\sigma _ {w _ 0}, \sigma _ {w _ 1})$ です。

分散共分散行列が対角成分にしか値を持たないということは、各成分が互いに無相関だということを仮定していることになる。いやいや $w _ 0, w _ 1$ には相関があるはずだ、と考えるならば対角で無い分散共分散行列を持つ正規分布を仮定する必要がある。今回はこれでひとまず良いだろう。ちなみに、多次元正規分布の場合は無相関⇔独立が成り立つ(通常は無相関⇐独立しか成立しない)。したがって、対角な多次元正規分布を近似分布に設定するのは、個々のパラメータに個別に一次元正規分布を変分近似分布として採用することと変わりない。

推論

ここまでできたら事後分布の推論はインターフェースに則って

adam_params = {"lr": 1e-1, "betas": (0.95, 0.999)}
oprimizer = pyro.optim.Adam(adam_params)

svi = infer.SVI(model, guide, oprimizer, loss=infer.JitTrace_ELBO())

n_steps = 5000
for step in range(n_steps):
    svi.step(x_data_th, y_data_th)

で終わりです。素晴らしい。 推論後の変分パラメータを確認するために下記を実行します。

for name in pyro.get_param_store():
    print("{}: {}".format(name, pyro.param(name)))

# AutoDiagonalNormal.loc: Parameter containing:
# tensor([1.0168, 3.9421], requires_grad=True)
# AutoDiagonalNormal.scale: tensor([0.0290, 0.0360], grad_fn=<AddBackward0>)

ひとつ目が、成分を2つ持つTensorで、これが平均ベクトルになっています。ふたつ目が、分散共分散行列の対角成分となっています。 なるほど妥当な平均が推論されていますね。

予測

ここからがMAP推定とは違います。もしも上記で推論した平均だけを値として取り出して、条件付き分布を構成すれば良いだけです。 今回はパラメータ自体のばらつきも考慮したベイズ予測分布

$$ p(y _ {new} \mid x _ {new}) = \int _ {w _ 0, w _ 1} p(y _ {new} \mid x _ {new}, w _ 0, w _ 1) p(w _ 0, w _ 1 \mid D) \mathrm dw _ 0 \mathrm d w _ 1 $$

を獲得しましょう($D$ は推論に使った手元のデータである。すなわち第二因子はパラメータの事後分布だ。)。変分推論をした場合は、事後分布を変分近似分布で肩代わりしたのだから、

$$ p(y _ {new} \mid x _ {new}) = \int _ {w _ 0, w _ 1} p(y _ {new} \mid x _ {new}, w _ 0, w _ 1) q(w _ 0, w _ 1) \mathrm dw _ 0 \mathrm d w _ 1 $$

を使うことになり、この分布は infer.Predictive クラスで簡単に得られます。

predictivedist_y = infer.Predictive(model=model, guide=guide, 
                                    num_samples=1000, return_sites=["y"])

ここで、実際には積分計算ができないので、コンピューター上で得られる予測分布は下記のサンプリング

$$ \begin{align} w _ 0, w _ 1 & \sim q(w _ 0, w _ 1) \\ y _ {new} &\sim p(y _ {new} \mid x _ {new}, w _ 0, w _ 1) \end{align} $$

num_samples 回だけ実施したサンプルでモンテカルロ近似を実施していることになります。 return_sites 引数はこの分布の戻り値を model 内の確率変数の名前によって指定できます。これは、例えばあるサンプリング1回で、どのような w がサンプルされ、その結果 y がどうなっているのかを対応付けたい場合などに用いることができるでしょう。今回は予測分布だけに興味があるので y だけを取り出します。

with torch.no_grad():
    predictive_sample_y = predictivedist_y.get_samples(x_index, None)

によって、入力した x_index に対して指定した回数のサンプリングを実施し、指定した確率変数のサンプルが返ってきます。 この戻り値は辞書型になっており

predictive_sample_y["y"]

などと値を見ることができます。なぜこの仕組みなのかというと、Predictive の引数 return_sites に複数の確率変数を指定する場合があるためです。

predictive_sample_y["y"].shape # -> torch.Size([1000, 100])

となっており、一番目がモンテカルロサンプリングの回数であり、二番目が x_index のデータ数(バッチサイズ)です。 サンプリングによって得られた平均値と標準偏差を取り出して予測のばらつきを見てみましょう。

predict_mean = predictive_sample_y["y"].mean(0)
predict_std = predictive_sample_y["y"].std(0)

plt.plot(x_data, y_data, "ro")
plt.plot(x_index.numpy(), predict_mean.numpy())
plt.fill_between(x_index.numpy(),
                 predict_mean - 3*predict_std,
                 predict_mean + 3*predict_std,
                 alpha=0.2)

f:id:s0sem0y:20191208182932p:plain

ベイズ予測区間が得られました。

MCMC

次は事後分布の近似推論をMCMCによって実施します。 変分近似分布を仮定する場合には、先程の例で言えば「2つのパラメータは無相関である」ということと「正規分布している」という仮定が盛り込まれていました。事後分布の実態がどうであれ、そのような近似分布を選んだのであれば、その中で最もKLダイバージェンスの意味で近い分布しか得られません。

計算量などの都合が無い場合は、事後分布の近似推論にはMCMCを用いるほうが良い場合多いです。なぜならMCMCは原理的には事後分布の形状を最終的には得られるものとなっているからです。 (ただし、実際には事後分布の形状をちゃんとサンプリングできるような遷移状態に素早く収束させる方法などを考える必要がある)

モデルはこれまでと同様のものを流用します。

MCMCの実行

PyroにはNUTSとHMCが実装されています。これはモデルの同時確率さえ導出すれば計算を実行できる手法であるため、モデルを書き終えたら直ちに推論を開始できます(同時確率の計算はモデルをたてたあとはPyroがよしなにやってくれる)。

今回はMCMCの確率遷移核として infer.NUTS を利用します。 jit_compile が導入されたからなのか、 v0.3.0 時代より感覚的には遥かに高速になっているように感じました。 infer.MCMC に確率遷移核のインスタンス、サンプリングしてほしい個数、サンプリングが独立に実施され始めるであろうバーンアウト期間を渡します。NUTSstep_size=True としている場合には、バーンアウト期間をステップサイズの同定に利用します。

nuts_kernel = infer.NUTS(model, adapt_step_size=True,
                         jit_compile=True, ignore_jit_warnings=True)
mcmc = infer.MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run(x_data_th, y_data_th)

事後分布

推論終了後は

mcmc.summary()
'''

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        w0      1.00      0.05      1.00      0.93      1.09    202.49      1.00
        w1      3.93      0.08      3.93      3.79      4.05    203.80      1.00
'''

と、各種統計量を確認することができます。また、

w0_mcmc = mcmc.get_samples()["w0"]
w1_mcmc = mcmc.get_samples()["w1"]

とサンプリングされた値を得ることができます。

予測分布

MCMCでサンプリングされた値を使って予測分布を得る場合もPredictive クラスを利用します。 変分推論の時は guide という引数と、モンテカルロ近似する際のサンプルサイズを指定していましたが、MCMCでは代わりに posterior_samples にサンプルを辞書で渡して使います。 (補足:もちろん、実際には変分推論のときにも guide から予め値をサンプリングして、そのサンプリングした値を今回のように posterior_samples に渡すこともできる)

predictivedist_y = infer.Predictive(model=model, 
                                    posterior_samples={"w0": w0_mcmc, "w1": w1_mcmc}, 
                                    return_sites=["y"])

あとはベイズ予測区間を表示してみます。

with torch.no_grad():
    predictive_sample_y = predictivedist_y.get_samples(x_index, None)

predict_mean = predictive_sample_y["y"].mean(0)
predict_std = predictive_sample_y["y"].std(0)

plt.plot(x_data, y_data, "ro")
plt.plot(x_index.numpy(), predict_mean.numpy())
plt.fill_between(x_index.numpy(),
                 predict_mean - 3*predict_std,
                 predict_mean + 3*predict_std,
                 alpha=0.2)

f:id:s0sem0y:20191208192746p:plain

変分推論とあまり変わらない結果になりましたが、もう少し複雑なモデルになると、一般にはMCMCの方が広い予測区間を持つ傾向にあります。

TensorFlow Probabilityに比べてインターフェースがだいぶ綺麗ですが、変分推論のカスタマイズをしたい場合などは、結局SVIクラスの中身を再構築する必要がありそうです。 また、MCMCも勾配ベースの手法しかなく、PyTorchの自動微分が使えるから実装されているくらいの感じでしょうか、そこまで力は入れていないように見えます。

www.hellocybernetics.tech