HELLO CYBERNETICS

深層学習、機械学習、強化学習、信号処理、制御工学、量子計算などをテーマに扱っていきます

Jaxでガウス過程 + NumPyroでハイパーパラメータ推論

 

 

follow us in feedly

モジュール

import jax.numpy as np
import jax
from jax import random, grad, vmap, jit, lax

import matplotlib.pyplot as plt
import seaborn as sns

import numpyro
import numpyro.distributions as dist
from numpyro import plate, sample, handlers
from numpyro.infer import MCMC, NUTS, SVI, ELBO

plt.style.use("seaborn")

key = random.PRNGKey(1)

データ

今回はsin波が指数減衰していくような関数を扱うことにします。

def toy_data(N, key=key):
    x = random.uniform(key, shape=(N, 1), minval=-2, maxval=2)
    y = np.sin(7*x) * np.exp(-0.5*x) + 0.5*onp.random.randn(N, 1)
    return x, y.squeeze()

x, y = toy_data(100)

f:id:s0sem0y:20200223141458p:plain

ガウス過程

ガウス過程はハイパーパラメータの推論を除けばすべて解析的に求まるため、jax.numpy を使って計算をすべて書いてしまいます。

カーネル関数

カーネル関数にはRBFを利用しますが、観測分散も一緒に取り込んでいる形式とそうでない形式の両方を準備します。後に予測を行う際、取り込んでいない形式を非対角のブロック行列の計算に用います。

@jit
def Kernel(X, Z, var, length):
    # distance between each rows
    dist_matrix = np.sum(np.square(X), axis=1).reshape(-1, 1)\
                + np.sum(np.square(Z), axis=1)\
                - 2 * np.dot(X, Z.T)
    return var * np.exp(-0.5 / length * dist_matrix)

@jit
def Kernel_with_noise(X, Z, var, length, noise, jitter=1e-6):
    return Kernel(X, Z, var, length) + np.eye(X.shape[0]) * (noise + jitter)

ちなみに、最初は素朴にグラム行列をfor文で書いたのですが、jitが異様に遅く使い物になりませんでした(この辺りはTensorFlowのtf.functionの凄みを感じた?)。その後、下記のjax.lax.fori_loop を利用しましたが、jitは良いとして自動微分に未対応のためハイパーパラメータの推論のときに詰まったためボツとしました。一応コードは下記のようになります。

# fori_loopが自動微分未対応?
# @jit
# def Kernel(X, Z, var, length):
#     I = X.shape[0]
#     J = Z.shape[0]

#     K = np.zeros(shape=[I, J])
#     def body(i, K):
        
#         def inner_body(j, K):
#             k_xz = rbf_kernel(X[i], Z[j], var, length)
#             K = jax.ops.index_update(
#                 K, jax.ops.index[i, j], k_xz
#             )
#             return K
        
#         K = lax.fori_loop(0, J, inner_body, K)
#         return K
        
#     K = lax.fori_loop(0, I, body, K)

#     return K

予測

グラム行列の計算手段があれば、あとは入力訓練データ、出力訓練データ、入力テストデータを用いて予測を直接書き下すことができます。これはパラメータの周辺化消去とカーネルトリックによって、パラメータが計算上から消え、通常の回帰モデルにおける学習結果と言えるものをデータで直接表現できるということです。

通常は、K_xx_inv の逆行列計算($O(n ^ 3)$)に多大なコストを要するため、様々な計算方法が提案されています。ここでは素朴にすべてのデータ点を用いて、逆行列を実直に計算する実装となっています。

def predict(rng_key, X, Y, X_test, var, length, noise):
    k_pp = Kernel_with_noise(X_test, X_test, var, length, noise)
    k_pX = Kernel(X_test, X, var, length)
    k_XX = Kernel_with_noise(X, X, var, length, noise)
    K_xx_inv = np.linalg.inv(k_XX)
    K = k_pp - np.matmul(k_pX, np.matmul(K_xx_inv, np.transpose(k_pX)))
    sigma_noise = np.sqrt(np.clip(np.diag(K), a_min=0.)) * jax.random.normal(rng_key, X_test.shape[:1])
    mean = np.matmul(k_pX, np.matmul(K_xx_inv, Y))
    return mean, mean + sigma_noise

決め打ちハイパーパラメータでの予測

データ点が2つでもとりあえずガウス過程は予測を出せます。今回はひとまず予測の平均値だけをプロットしてみます。

f:id:s0sem0y:20200223142356p:plain

次第にデータ点を増やしていきましょう。

f:id:s0sem0y:20200223142455p:plain

f:id:s0sem0y:20200223142542p:plain

f:id:s0sem0y:20200223142600p:plain

最終段階では真の関数も一緒に表示します。データ点自体が真の関数から誤差を持って観測されることに加え、ハイパーパラメータが決め打ちであるために、上手く予測ができていないようにも見えます。

f:id:s0sem0y:20200223142702p:plain

ガウス過程はデータ点を追加することで、いわゆる普通の回帰モデルにおける学習に相当する結果を直ちに得られます。一方でハイパーパラメータを調整するというのは、普通の回帰モデルで言うと多項式の次数を調整したり非線形の項を入れてみたりと、モデルそのものを調整する作業に相当します(なぜなら、ガウス過程のハイパーパラメータは関数がどれくらい曲がれるのに関わるからである)。

MCMC でのハイパーパラメータ推論

モデル

モデルは非常にシンプルに書くことができます。既にガウス過程における重要な計算は実装しているためです。やらなければならないことは、ハイパーパラメータに事前分布を与えてやることと、ガウス過程の定義である「どのような観測点のセットを選んでも、その観測点らが(データ数の次元での)多変量正規分布に従う」という形式でサンプリングを行うことです。

このときの共分散行列が、入力データ点によるカーネルグラム行列で記述されるのでした。

def model(X, Y):
    var = sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = sample("kernel_length", dist.LogNormal(0.0, 10.0))

    K = Kernel_with_noise(X, X, var, length, noise)

    sample(
        "Y", 
        dist.MultivariateNormal(loc=np.zeros(X.shape[0]), covariance_matrix=K),
        obs=Y)

事前分布からのサンプリング

事前分布から選ばれたハイパーパラメータによってサンプリングを実施してみます。 これはよもや決め打ちのハイパーパラメータよりもいい加減な結果になることでしょう(事前分布の分散を見よ)。

trace_model = handlers.trace(handlers.seed(model, key))
prior_Y = trace_model.get_trace(random.normal(key, shape=(500, 1)), None)["Y"]

plt.plot(prior_Y["value"], "o")

f:id:s0sem0y:20200223143635p:plain

事後分布の推論

NumPyroでのMCMCは非常に明快なAPIとなっています。 今回のハイパーパラメータはすべて正に値を取るべき変数たちであるので、MCMCの内部で負の値がサンプリングされてしまうと不具合が生じます。そのため、通常では適当な変数変換によってMCMCの空間では実数全体を探索させておき、ハイパーパラメータとして使うときには適切な制約された空間に収まるように仕立て上げます。

NumPyroにおいてそれらの処理は、事前分布の定義域から自動で実施してくれるため特にユーザーが行う設定はありません。(TensorFlow Probabilityではtfp.bijectorsを用いて各々パラメータに対して適切な変換を準備する必要があります。)

def run_inference(model, warm_up, samples, key, X, Y):
    kernel = NUTS(model)
    mcmc = MCMC(kernel, warm_up, samples,
                progress_bar=True)
    mcmc.run(key, X, Y)
    mcmc.print_summary()
    return mcmc.get_samples()

samples = run_inference(model, 500, 1000, key, x_data, y_data)


sample: 100%|██████████| 1500/1500 [00:13<00:00, 113.48it/s, 11 steps of size 4.40e-01. acc. prob=0.93]

                     mean       std    median      5.0%     95.0%     n_eff     r_hat
  kernel_length      0.06      0.03      0.06      0.02      0.10    402.01      1.00
   kernel_noise      0.24      0.06      0.24      0.15      0.33    719.51      1.00
     kernel_var      3.85      3.57      2.84      0.82      7.20    386.16      1.00

Number of divergences: 0
plt.figure(figsize=(15, 4))
plt.subplot(131)
sns.distplot(samples["kernel_length"])
plt.title("length")
plt.subplot(132)
sns.distplot(samples["kernel_var"])
plt.title("var")
plt.subplot(133)
sns.distplot(samples["kernel_noise"])
plt.title("noise")

f:id:s0sem0y:20200223144059p:plain

予測分布

予測分布を出すためには、上記で得た事後分布(からのサンプリング)を用います。 今回は事後分布からのサンプリングを1000点準備しているので、ある1つの $x$ に対して $y$ が1000個計算できることになります。その結果から平均と分散を用いて簡易的に予測分布を表示することができます。

今回は vmap を用いて、既に実装されているpredict 関数をvectorizeします。 各々の予測において異なる乱数シードを用いるように、乱数シードも1000個準備することとします。

rng_key_predict = random.split(random.PRNGKey(0), num=1000)

predict 関数における、「バッチ処理が必要な変数」のみが引数となるように lambda 式で関数を作ります。

predict_samples = vmap(
    lambda rng_key, var, length, noise:
    predict(rng_key, x_data, y_data, x_test, var, length, noise)
)

準備が整いました。これで予測分布を計算してみます。

mean, mean_noise = predict_samples(
    rng_key_predict,
    samples["kernel_var"],
    samples["kernel_length"],
    samples["kernel_noise"]
)

mean_mean = mean_noise.mean(0)
mean_std = mean_noise.std(0)


plt.figure(figsize=(15, 6))
plt.plot(x_data, y_data, "o", alpha=0.5)
plt.plot(x_test, mean_mean, "b")
plt.fill_between(x_test.squeeze(), mean_mean-3*mean_std, mean_mean+3*mean_std, alpha=0.2)
plt.plot(x_true, y_true, "g")

f:id:s0sem0y:20200223144730p:plain

ハイパーパラメータの推論を実施したことで、決め打ちのときより遥かに良い予測ができているようです。

ガウス過程関連の記事

www.hellocybernetics.tech

www.hellocybernetics.tech

www.hellocybernetics.tech

www.hellocybernetics.tech