HELLO CYBERNETICS

深層学習、機械学習、強化学習、信号処理、制御工学、量子計算などをテーマに扱っていきます

TensorFlow probabilityでレプリカ交換モンテカルロ法

 

 

follow us in feedly

はじめに

これはもはやただの備忘録です。 Pyroに無い機能をこっちで動かす確認をしたかったというだけでございます。

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
from tensorflow_probability import bijectors as tfb

import numpy as np
import matplotlib.pyplot as plt

plt.style.use("seaborn")

を前提とします。

データ

def toy_data():
    sigma1 = 1.5
    mu1 = 10.0
    dist1 = tfd.Normal(loc=mu1, scale=sigma1)
    
    sigma2 = 3.0
    mu2 = 2.0
    dist2 = tfd.Normal(loc=mu2, scale=sigma2)
    return tf.concat([dist1.sample(100), dist2.sample(200)], axis=0)

やっつけでデータを作成しております。個々のデータは単に正規分布で、それが1:2の割合で混ざっています。 見かけ上は負担率1/3の混合正規分布のようなヒストグラムになります。

f:id:s0sem0y:20191026081507p:plain

モデル

ということで、正規分布2つが混ざっている混合モデルを適当な事前分布を仮定して作ってみます。 データを見ればおおよそ各々の平均や分散が見積もれるので、その付近にばらつきを持った事前分布を仮定しました。

標準偏差は正の値であるべきですので、ガウス分布から生成した log_sigmatf.math.exp に食わしてパラメータとして与えます。 負担率は 0から1の値ですので、logit を作ってから tf.math.sigmoid に食わせます。再パラメータ化はMCMCの結果に影響するので注意が必要です。 あまり非線形性の強い変換(曲率の大きな変換)はその付近でサンプリングが敏感に変わってしまうことになるでしょう。

再パラメータ化せずに、分散に対しては正の値しかサンプリングしない逆ガンマ分布であったり、負担率であれば0から1しか値を取らないベータ分布を使う方法もあります。

root = tfd.JointDistributionCoroutine.Root
def model():
    mu1 = yield root(tfd.Normal(loc=3.0, scale=3))
    log_sigma1 = yield root(tfd.Normal(loc=0, scale=2))
    sigma1 = tf.math.exp(log_sigma1)
    
    mu2 = yield root(tfd.Normal(loc=10, scale=3))
    log_sigma2 = yield root(tfd.Normal(loc=0, scale=2))
    sigma2 = tf.math.exp(log_sigma2)

    logit = yield root(tfd.Normal(loc=0, scale=1.0))
    prob = tf.math.sigmoid(logit)
    
    x = yield tfd.Sample(
        tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=[prob, 1-prob]),
            components_distribution=tfd.Normal(
                loc=[mu1, mu2],
                scale=[sigma1, sigma2],
            )
        ),
        sample_shape=300
    )
    
joint = tfd.JointDistributionCoroutine(model)

上記のモデルの作り方の基本は下記で

www.hellocybernetics.tech

遷移核

サンプリングしたい分布は同時分布に対して、観測データ sample を与えた状態にすれば得られます。 サンプリングを実施するための遷移核にはNUTSを使い、レプリカ交換モンテカルロ法のクラスでラッピングします。

この時、make_kernel_fn にはtarget_log_prob_fn と書かれているサンプリングしたい分布と、ランダムシードを引数にしておきます。 なぜかランダムシードの引数がないと、後で怒られます。

レプリカ交換モンテカルロのクラスでラップするときに、seed=tf.random.set_seed()を利用しなければなりません。intを直接与えず、こうしておかないとTF2.0が怒ります。謎。 しかもintを与えた時に出てくるエラーメッセージは、存在しない架空の(過去にはあったのだろうが)メソッドを使えと表示されるため、解決方法を見つけるまでハマりました。

sample = toy_data()

def unnormalized_log_prob(mu1, log_sigma1, mu2, log_sigma2, logit):
    return tf.reduce_mean(joint.log_prob([mu1, log_sigma1, mu2, log_sigma2,
                                          logit, sample]))


def make_kernel_fn(target_log_prob_fn, seed):
    return tfp.mcmc.NoUTurnSampler(
            target_log_prob_fn=target_log_prob_fn,
            step_size=0.05,
            seed=seed
    ) 
     

remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=unnormalized_log_prob,
    inverse_temperatures=[tf.constant(0.2), 
                          tf.constant(0.2), 
                          tf.constant(0.2), 
                          tf.constant(0.2), 
                          tf.constant(0.2)],
    make_kernel_fn=make_kernel_fn,
    seed=tf.random.set_seed(1)
)

あとは回すだけ

@tf.function()
def run_chain():
    with tf.device("/gpu:0"):
        init_state = list(joint.sample()[:-1]) 
        chains_states, kernels_results = tfp.mcmc.sample_chain(
            num_results=1000,
            num_burnin_steps=300,
            current_state=init_state,
            kernel=remc,
            parallel_iterations=50
        )
    return chains_states, kernels_results



chain_states, kernel_results = run_chain()

言っておきますがめちゃくちゃ遅いです。いや、一次元なんだからGPU使う必要はなかったのだろうか…?よくわからないけど、どっちにしてもアクビが出るほど遅かったです(多峰が相手になったときもNUTSを初期値変えてサンプリングを複数回繰り返して事後分布を結合したほうがマシなんじゃないか?くらい遅い)。

ひとまず備忘録なのでここまで。