はじめに
今回は適当な人工データで多項式回帰を実施します。MAP推定、変分ベイズ推論と見ていきます。今回はそれぞれの学習方法が「どのような性能を達成するか」という視点ではなく、実装を通して「具体的に何を計算しているのか」を見ていくことにします。
今回は、PPLであるPyroを用いずに、PyTorchのみを使って実装を行います。そうするとMAP推定までは良いのですが、変分ベイズ推論が極めて面倒なコードを書かなければなりません。何を計算しなければならないのかを知っていれば、書く前から面倒だなということは分かるのですが、実際に書くことでPPLの恩恵を肌で感じつつ、変分ベイズ推論の動作を理解していきましょう。
ちなみに今回はPyTorchを使いますが、TensorFlow Probabilityのtfp.distributions
モジュールを使えばほとんど同様のコードで同じことができるはずです(TFPの変分推論は通常Edward2を用いるか、tfp.layers
モジュールを用いる)。
利用するモジュール
import numpy as np import matplotlib.pyplot as plt import torch import torch.distributions as torchdist
torch.distributions
の基本
分布の記述
まず、肩慣らしに標準正規分布 $$\rm {Normal} (0, 1)$$ を書いてみましょう。torch.distributions
モジュールを使うことで下記のように記述することができます。
normal_dist = torchdist.Normal(loc=torch.tensor(0.), scale=torch.tensor(1.))
サンプリング
続いて、記述した分布から $$ x \sim {\rm Normal}(0, 1) $$ とサンプリングを得るには下記のようにします。
x = normal_dist.sample()
添字 $i$ を用いて $x[i]$ として、個々の $x[i]$ をそれぞれ標準正規分布から独立にサンプルされたデータ
$$ X[i] \sim {\rm Normal}(0, 1) $$
としたければ下記の用にできます。ここでサンプルの数は10個としておきましょう。
X = normal_dist.sample([10])
また、上記の方法だと1次元配列に10個の要素があり、それぞれの要素が標準正規分布からのサンプルとなります。標準正規分布からのサンプルを例えば
$$ X[i, j] \sim {\rm Normal}(0, 1) $$
と、2次元配列に格納したければ下記のように書くことが出来ます。
X = normal_dist.sample([10, 5])
対数尤度の計算
対数尤度の計算は分布に対して実際に生成されているサンプルを渡せば実施できます。
sample = normal_dist.sample() log_likelihood = normal_dist.log_prob(sample)
これで、サンプル$x$の対数尤度が得られます。サンプルが複数ある場合、それぞれのサンプルの対数尤度を返してきます。
samples = normal_dist.sample([5]) log_likelihood = normal_dist.log_prob(samples) print(log_likelihood.shape) # -> (5,)
したがって、全てのサンプルの対数尤度の和はlog_likelihood.sum()
などで計算する必要があることに注意しましょう。
後の応用では推定を行う際に手持ちのデータの対数尤度を計算する必要があるので、形式的には下記のような計算が行われます。
samples = train_x
loss_value = - normal_dist.log_prob(samples).sum()
上記ではデータの負の対数尤度を計算しています。
MAP推定
用いるデータ
スカラーの入力 $x$ とスカラーの出力 $y$ の関係性が下記の用に多項式で表されるデータを生成します。
$$ y = -3 + 4x + x^2 + \epsilon $$
ただし、ここで $\epsilon \sim {\rm Normal}(0, 1)$ の正規乱数としておきます。
def toy_poly(): x = 5 * torch.rand(100, 1) linear_op = -3 - 4*x + 1*x**2 y = torchdist.Normal(linear_op, 1).sample() return x, y x_train, y_train = toy_poly() plt.plot(x_train.numpy(), y_train.numpy(), "o")
モデル
さて、上記のデータを見て、二次関数で表現できると睨んだとしましょう(出来レースですが)。
$$ y = w_0 + w_1x + w_2x^2 + \epsilon $$
ここで $\epsilon \sim {\rm Normal}(0, 1)$ の正規乱数としておきます。 すると、簡単な式変形から下記のような確率モデルを使うことが考えられます($\epsilon$の分散は既知としているが、未知にすることも当然必要であれば考えられる)。
$$ y - (w_0 + w_1x + w_2x^2) \sim {\rm Normal}(0,1) $$
あるいは、分布の平均の方にパラメータを持つ項を吸収させて
$$ y \sim {\rm Normal}(w_0 + w_1x + w_2x^2 ,1) $$
とできます。これで、求めたいのは $w_0, w_1, w_2$ というパラメータたちになりますが、これらのデータ $D = \{(x_1, y_1), \cdots, (x_N, y_N)\}$に基づく事後分布は下記のように書き下せます。
$$ \begin{align} p(w_0, w_1, w_2 \mid D) &= \frac{p(w_0, w_1, w_2, D)}{p(D)} \\ &= \frac{p(D\mid w_0, w_1, w_2)p(w_0, w_1, w_2)}{p(D)} \end{align} $$
ここで、事前分布を$p(w_0, w_1, w_2)$を適当に決めて、上記の式の分子のみを見て最大化を実施すれば良いのですが、ひとまず
$$ (w_1, w_2, w_3) \sim {\rm MultiNormal}({\bf 0}, {\rm diag}(\sigma_1^2, \sigma_2^2, \sigma_3^2)) $$
と分散共分散行列が対角行列の多変量正規分布を事前分布を適当に置いてしまいましょう。このように複数のパラメータを1つの多次元分布から生起させてもいいですし、個々に個別の分布を仮定しても構いません。今回のように対角行列が分散共分散行列となっているような多変量正規分布は各成分が無相関であり、多変量正規分布の場合は無相関と独立が同値になります。
したがって、実は上記の用に事前分布を置くことは、
$$ \begin{align} w_1 &\sim {\rm Normal} (0, \sigma_1) \\ w_2 &\sim {\rm Normal} (0, \sigma_2) \\ w_3 &\sim {\rm Normal} (0, \sigma_3) \end{align} $$
とするのと変わりません。すると結局最大化したいのは
$$ \begin{align} & p(D\mid w_0, w_1, w_2)p(w_0, w_1, w_2) \\ = & p(D\mid w_0, w_1, w_2)p(w_0)p(w_1)p(w_2) \end{align} $$
と表せます。
目的関数
定石通り、上記の対数を計算することにすれば
$$ \begin{align} {\rm LogJointProb}(w_0, w_1, w_2, X, Y) &= \frac{1}{N} \sum_{i=1}^N \log [{\rm Normal}(y_i\mid w_0 + w_1x_i + w_2x_i^2 ,1)] \\ & \ \ \ +{\rm log} [{\rm Normal}(w_0 \mid 0, \sigma_1) ] \\ & \ \ \ +{\rm log} [{\rm Normal}(w_1\mid 0, \sigma_2)] \\ & \ \ \ +{\rm log}[ {\rm Normal}(w_2 \mid 0, \sigma_3)] \end{align} $$
と書き表せます。ここで $\rm LogJointProb$ という命名は、結局MAP推定のときは事後分布の分子だけを見ており、ココはデータとパラメータの同時分布に他ならないからです。(グラフィカルモデルから考えていくと、確率変数を全て描いて同時分布を表しておいて、有向グラフとして確率変数の生成過程を記述するときに自然と同時分布から条件付き分布への形が導出される。)
こうしてモデルが対数同時分布として書き下せたのであれば、これを目的関数として最大化すれば良いです。普通は負を取って最小化問題とします。
コードでは関数log_joint_prob(w0, w1, w2, x, y)
内に事前分布と尤度として使う分布をと書き下しておき、return
でそれらの分布で計算される対数同時確率を返すようにします。引数である x, y
は訓練データであり、w0, w1, w2
が現在のパラメータの値ということになります。
def log_joint_prob(w0, w1, w2, x, y): prior_w0 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.)) prior_w1 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.)) prior_w2 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.)) linear = w0 + w1*x + w2*x**2 likelihood = torchdist.Normal(linear, torch.ones_like(linear)) return ( prior_w0.log_prob(w0) + prior_w1.log_prob(w1) + prior_w2.log_prob(w2) + likelihood.log_prob(y).mean() )
学習コード
さて、推定したいパラメータを宣言しておきます。
w0 = torch.nn.Parameter(torch.tensor(1.)) w1 = torch.nn.Parameter(torch.tensor(1.)) w2 = torch.nn.Parameter(torch.tensor(1.))
あとは普通にPyTorchの自動微分の機能を使って
optimizer = torch.optim.Adam(params=[w0, w1, w2], lr=1e-3) for i in range(30000): optimizer.zero_grad() log_joint_prob_value = log_joint_prob(w0, w1, w2, x_train, y_train) loss_value = - log_joint_prob_value loss_value.backward() optimizer.step() if (i+1) % 1000 == 0 or (i==0): print(loss_value.detach().numpy())
PyTorchっぽく書く
今まで関数で地道に表現した数式を、PyTorchのnn.Module
を使ってラッピングするだけです。事前分布とパラメータを nn.Module
にもたせてあげれば良いだけの簡単な仕様です。ここらへんはPyTorchがシンプルな設計になっているので書きやすいです。(tf.keras
モジュールやchainer
のようにModel
(Chain
)とLayer
(Link
)をhas-aにする思想になっていない)。
学習後は model.forward()
でフィッティングした予測曲線が得られます。
def toy_poly(): x = 5 * torch.rand(100, 1) linear_op = -3 - 4*x + 1*x**2 y = torchdist.Normal(linear_op, 1).sample() return x, y x_train, y_train = toy_poly() class MapRegression(torch.nn.Module): def __init__(self): super(MapRegression, self).__init__() self.prior_w0 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.)) self.prior_w1 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.)) self.prior_w2 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.)) self.w0 = torch.nn.Parameter(torch.tensor(1.)) self.w1 = torch.nn.Parameter(torch.tensor(1.)) self.w2 = torch.nn.Parameter(torch.tensor(1.)) def forward(self, x): return self.w0 + self.w1*x + self.w2*x**2 def log_joint_prob(self, x, y): linear = self.forward(x) likelihood = torchdist.Normal(linear, torch.ones_like(linear)) return ( self.prior_w0.log_prob(self.w0) + self.prior_w1.log_prob(self.w1) + self.prior_w2.log_prob(self.w2) + likelihood.log_prob(y).mean() ) model = MapRegression() optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3) for i in range(30000): optimizer.zero_grad() log_joint_prob_value = model.log_joint_prob(x_train, y_train) loss_value = - log_joint_prob_value loss_value.backward() optimizer.step() if (i+1) % 1000 == 0 or (i==0): print(loss_value.detach().numpy())
変分推論
ベイズ推論は事後分布
$$ p(\theta \mid D) = \frac{p(D\mid \theta)p(\theta)}{p(D)} $$
において、同時分布の最大化に甘んじず(すなわち関数の一番山となる部分だけを探すのではなく)、分布の形状全体を把握しようという試みになります。その試みは一般に困難を極めます。都合の良い尤度関数と都合の良い事前分布を選ばない限りは形状全体を上手に求めることはできません。
したがって、形状全体を知りたいのだが、ある程度簡略化した形状で一番近いものを探せればいいというのが変分推論です。変分モデル $q(\theta; \eta)$ を仮定し($\eta$は変分パラメータと呼ばれる最適化すべきパラメータである)、$\eta$ の調整で分布の形状を変えること $p(\theta |D)$ に最も近い $q(\theta; \eta)$ を決定します。近いというのはKLダイバージェンスの意味であり
$$ {\rm KL}[q(\theta; \eta) : p(\theta\mid D)] = {\mathbb E}_{q(\theta;\eta)}[{\rm log}q(\theta; \eta)] - {\mathbb E}_{q(\theta;\eta)}[{\rm log}p(\theta)] - {\mathbb E}_{q(\theta;\eta)}[{\rm log}p(D\mid\theta)] $$
を最小化するような $\eta$ を求めます。
こちらも最適化問題ではありますが、求めているものはパラメータ $\theta $ の値ではなくパラメータ $\theta$ が取りうる値の分布を網羅的に把握するために $\eta$ を最適化していることに注意しましょう。最適化された $\eta$ によって分布 $q(\theta ; \eta)$ が定まり、この分布からサンプリングをしたりすることで、単に点推定で $\theta$ を決めてしまうよりも多くの情報を利用することができるというわけです。
実際の推論では期待値計算の代わりに現在の $\eta$ の値を用いて
$$ \theta^* \sim q(\theta; \eta) $$
とサンプリングし、サンプリングされた$\theta^*$ で現在のKLダイバージェンスを計算するということにします(そんなのいい加減すぎる!と思うのであれば、$q(\theta; \eta)$ は現在の $\eta$ を使うとして重点サンプリングなどをしてもいいだろうし、大げさにもMCMCを使ってもいいだろう。単に計算量の問題である)。
変分モデル
回帰問題の例に戻って、パラメータ $w_0, w_1, w_2$ に対してそれぞれ変分モデル
$$ q(w_i ; \eta_i) = {\rm Normal}(\mu_i, \sigma_i) $$
を仮定しましょう。すなわち各 $w_i$ に対して正規分布を仮定して、あとはそれぞれの平均分散を変分パラメータとして最適化して $w_i$ の分布を得てしまおうということにしたのです。
ここで変分パラメータの標準偏差の代わりに、その対数を変分パラメータに取っています。理由は最適化の中で正負にも自由に値を取れるようにスケーリングがしたいからです。学習したパラメータの指数を取ってからサンプリングに使うようにすれば問題ありません。
variational_params = { "w0_loc": torch.nn.Parameter(torch.tensor(0.)), "w0_scale_log": torch.nn.Parameter(torch.tensor(0.)), "w1_loc": torch.nn.Parameter(torch.tensor(0.)), "w1_scale_log": torch.nn.Parameter(torch.tensor(0.)), "w2_loc": torch.nn.Parameter(torch.tensor(0.)), "w2_scale_log": torch.nn.Parameter(torch.tensor(0.)), } def variational_model(variational_params): """ Variational model q(w; eta) arg: variational parameters "eta" return: w ~ q(w; eta) """ w0_q = torchdist.Normal( variational_params["w0_loc"], torch.exp(variational_params["w0_scale_log"]), ) w1_q = torchdist.Normal( variational_params["w1_loc"], torch.exp(variational_params["w1_scale_log"]), ) w2_q = torchdist.Normal( variational_params["w2_loc"], torch.exp(variational_params["w2_scale_log"]), ) return w0_q, w1_q, w2_q
目的関数
目的関数はKLダイバージェンスになりますので、これを記述します。sample()
メソッドの代わりにrsample()
メソッドを利用していますが、これは計算グラフをトレースするときに必要になる情報を保持する場合に rsample()
を利用するようで、今回は標準偏差を対数スケールにしたりするなどの変換が行われているので、必要になるっぽいです。対数同時分布についてはMAP推定の関数を使い回すことが出来ます。
def kl_divergence(variational_params, x, y): w0_q, w1_q, w2_q = variational_model(variational_params) w0_sample = w0_q.rsample() w1_sample = w1_q.rsample() w2_sample = w2_q.rsample() log_joint_prob_value = log_joint_prob(w0_sample, w1_sample, w2_sample, x, y) log_variational_prob_value = ( w0_q.log_prob(w0_sample) + w1_q.log_prob(w1_sample) + w2_q.log_prob(w2_sample) ) return log_variational_prob_value - log_joint_prob_value
学習コード
optimizer = torch.optim.SGD(params=variational_params.values(), lr=1e-4) for i in range(9000): optimizer.zero_grad() loss_value =kl_divergence(variational_params, x_train, y_train) loss_value.backward() optimizer.step() if (i+1) % 300 == 0 or (i==0): print(loss_value.detach().numpy())
変分推論をPyTorchっぽく書く
基本的にモデルが変わらず保持しておく必要があるのは、事前分布と変分パラメータになります。 適宜、対数同時確率を計算する場合に必要であり、変分パラメータは、変分モデルからサンプリングをしたいあらゆる場面で必要になります。 注意が必要なのは、KLダイバージェンスを計算する際には、その1回の評価において尤度・事前分布・変分分布の対数尤度に同じサンプルが使われなければなりません。尤度の評価・事前分布の評価・変分分布の評価毎にサンプリングをしてしまったら、違うサンプルが使われてしまうので注意が必要です。
class VImodel(torch.nn.Module): def __init__(self): super(VImodel, self).__init__() self.prior_w0 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.)) self.prior_w1 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.)) self.prior_w2 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.)) self.w0_loc = torch.nn.Parameter(torch.tensor(0.)) self.w0_scale_log = torch.nn.Parameter(torch.tensor(0.)) self.w1_loc = torch.nn.Parameter(torch.tensor(0.)) self.w1_scale_log = torch.nn.Parameter(torch.tensor(0.)) self.w2_loc = torch.nn.Parameter(torch.tensor(0.)) self.w2_scale_log = torch.nn.Parameter(torch.tensor(0.)) def forward(self, x, sample_params=None): if sample_params==None: w0_q, w1_q, w2_q = self.variational_model() w0 = w0_q.rsample() w1 = w1_q.rsample() w2 = w2_q.rsample() else: w0, w1, w2 = sample_params linear = w0 + w1*x + w2*x**2 return linear def variational_model(self): w0_q = torchdist.Normal( self.w0_loc, torch.exp(self.w0_scale_log), ) w1_q = torchdist.Normal( self.w1_loc, torch.exp(self.w1_scale_log), ) w2_q = torchdist.Normal( self.w2_loc, torch.exp(self.w2_scale_log), ) return w0_q, w1_q, w2_q def log_joint_prob(self, x, y, sample_params): linear = self.forward(x, sample_params) likelihood = torchdist.Normal(linear, torch.ones_like(linear)) return ( self.prior_w0.log_prob(sample_params[0]) + self.prior_w1.log_prob(sample_params[1]) + self.prior_w2.log_prob(sample_params[2]) + likelihood.log_prob(y).sum() ) def kl_divergence(self, x, y): w0_q, w1_q, w2_q = self.variational_model() w0_sample = w0_q.rsample() w1_sample = w1_q.rsample() w2_sample = w2_q.rsample() sample_params = (w0_sample, w1_sample, w2_sample) log_joint_prob_value = self.log_joint_prob(x, y, sample_params) log_variational_prob_value = ( w0_q.log_prob(w0_sample) + w1_q.log_prob(w1_sample) + w2_q.log_prob(w2_sample) ) return log_variational_prob_value - log_joint_prob_value model = VImodel() optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-4) for i in range(9000): optimizer.zero_grad() loss_value =model.kl_divergence(x_train, y_train) loss_value.backward() optimizer.step() if (i+1) % 300 == 0 or (i==0): print(loss_value.detach().numpy())
事後分布を求めたい変数が増えれば自ずと変分パラメータも増えますし、サンプリングの管理も大変になってくると思われます。そこらへんの管理をライブラリ側に任せつつ、条件付き分布なり期待値計算なりの便利なモジュール群も取り揃えてくれるのがpyro等のPPLになります。