HELLO CYBERNETICS

深層学習、機械学習、強化学習、信号処理、制御工学、量子計算などをテーマに扱っていきます

NumPyroとJax Numpyで時系列

 

 

follow us in feedly

はじめに

Pyroにはpyro.plate という(条件付き)独立のサンプリングをベクトル化してくれる便利な機能があります。そして更に、時系列的なサンプリングを実施してくれる pyro.marcov なる機能まであり、PPL万歳!という雰囲気なのですがいかんせんMCMCがアホみたいに遅いです。

NumPyroはJaxのNumPyをバックエンドにしたPyroという立ち位置で、謎の技術(誰か解説して…)によりPyTorchバックエンドのPyroよりも圧倒的にMCMCが速いという特徴があります。代わりにPyroにあったmarcovの機能であったり、離散的潜在変数をよしなに周辺化サンプリングしてくれる機能が無くなっております。

となると、結構HMMとかキツい…という話なのですが、HMMどころかそもそも時系列モデルを書くのがJaxバックエンドのために若干難しいという欠点もあります(なぜかと言うとJitの都合上、Pythonのfor loopを素朴に書いているとコンパイルがとにかく遅い)。

というわけで時系列モデルを書くためには jax.lax.scan関数を上手に使うということになるのですが、その練習も兼ねてひたすらコードを書き連ねてみます。

ちなみにタイムラインに

tjo.hatenablog.com

の話がちらっと出ており、この記事は交差検証の必要性を説く話なのですが、時系列データ(ランダムウォーク)に多項式フィッティング自体どうなんだ?という話がポロッと出ていたので、ランダムウォークのデータを扱ってみます。

ちなみにランダムウォークの今後を予測することは端から不可能なのですが、ベイズモデリングをしてみると不可能に近いほど不確実性が高いことを示してくれるので、やる意味が無いということは無いと思います。多分。

コード

モジュール

特に個別にコードの解説などはしません。 最低限のコードに絞っているつもりなので、NumPyroの基本的なことがわかっていれば読めるはずです。

www.hellocybernetics.tech

www.hellocybernetics.tech

www.hellocybernetics.tech

import jax.numpy as np
from jax import random, vmap, grad, jit, lax

import numpy as onp
import matplotlib.pyplot as plt

import numpyro
from numpyro import plate, sample, infer, handlers
import numpyro.distributions as dist

plt.style.use("seaborn")

データ

データ $X _ t$ をランダムウォークで扱います。

$$ X _ t = X _ {t - 1} + \epsilon _ t \\ \epsilon _ t \sim {\rm Normal}(0, \sigma) $$

こんな感じで、1個前の時刻の値にガウス分布から発生したノイズが加算されて次の時刻の値になります。 なので、i.i.d.でガウス雑音を生成しておいて np.cumsum でランダムウォークを簡単に作ることができます。

X = np.cumsum(0.5*onp.random.randn(100))
plt.plot(X)

f:id:s0sem0y:20200507234639p:plain

おー一見意味がありそうだからたちが悪い…!

モデル

モデルは下記のように書いておきます。 MCMCModel なんて親クラス、この記事では分けておく必要ないのですが、実際にはcolabで他のモデルもごちゃごちゃ書いて遊んでいたので、とりあえずこの形式で載せてしまいます。

class MCMCModel():

    def infer_model(self):
        pass

    def inference(self, warmup, n_samples):
        nuts = infer.NUTS(model.infer_model)
        self.mcmc = infer.MCMC(nuts, warmup, n_samples)
        self.mcmc.run(random.PRNGKey(0))        
        

class Model1(MCMCModel):
    def __init__(self, Xs, N):
        self.Xs = Xs
        self.N = N

    def infer_model(self):
        return self.model(self.Xs, self.N)

    def model(self, Xs, N):
        sigma_mu = sample("sigma_mu", dist.HalfNormal(10))
        sigma_X = sample("sigma_X", dist.HalfNormal(10))

        init_mu = sample("init_mu", dist.Normal(0, 1))
        epsilons = sample("epsilons", dist.Normal(np.zeros((N)), 
                                                sigma_mu*np.ones((N))))
        
        _, mu = lax.scan(self.f, init_mu, epsilons, N)
        numpyro.deterministic("mu", mu)
        X_sample = sample("X_sample", dist.Normal(mu, sigma_X), obs=Xs)
        return X_sample

    @staticmethod
    def f(carry, epsilon):
        carry = carry + epsilon
        return carry, carry

    def forecast(self, N_forecast):
        samples = self.mcmc.get_samples()
        # epsilons (N_forecast, samples)
        epsilons = sample("epsilons", dist.Normal(np.zeros((N_forecast, 1)), 
                                      samples["sigma_mu"]*np.ones((N_forecast, 1))))
        # init_mu (samples, ) 
        init_mu = samples["mu"][:, -1]
        # mu (N_forecast, samples)
        _, mu = lax.scan(self.f, init_mu, epsilons, N_forecast)
        return mu.T

f(carry, epsilon) が結構肝です。この関数はjax.lax.scan 関数に渡されます。大事なことは、A, C = f(A, B) という形式で関数を定義するということです。jax.lax.scan の振る舞いをPythonで表記すると下記のようになります。

def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)

下から3行目で f が使われます。carryx を使って何らかの計算を実施し、carryy を返すような関数です。そして、f は forループの中におり、carry が次の f の引数に使われるということになります。そして forの中で戻り値 y は常にメモリに保存されているということにも注目しましょう。for を抜けたあとに、scan関数は 最新の carry とこれまですべての y を返します。for ループはscan 関数の引数である xs を一つずつ取り出して f に与えている点も注意しましょう。

今回のモデルを記述する上では、f(carry, x) を書くときに、carry が状態を x がガウスノイズを与えるようにします。

    @staticmethod
    def f(carry, epsilon):
        carry = carry + epsilon
        return carry, carry

そうするためには、ガウスノイズを scan の外でi.i.d.で全時系列分発生させておくことになります。

    def model(self, Xs, N):
        sigma_mu = sample("sigma_mu", dist.HalfNormal(10))
        sigma_X = sample("sigma_X", dist.HalfNormal(10))

        init_mu = sample("init_mu", dist.Normal(0, 1))
        ## N 個のガウス雑音を予め生成しておく
        epsilons = sample("epsilons", dist.Normal(np.zeros((N)), 
                                                sigma_mu*np.ones((N))))
        

        _, mu = lax.scan(self.f, init_mu, epsilons, N)
        numpyro.deterministic("mu", mu)
        X_sample = sample("X_sample", dist.Normal(mu, sigma_X), obs=Xs)
        return X_sample

lax.scan 関数で遊んでみるのが一番速いです。jaxでのLSTM等の実装も基本的にはこれを使うコトになると想われます(Jitを使う場合は…の話だが)。(なんでモデルが状態空間モデルになってるのか。X_sample なんて観測式が必要なのかどうかは微妙なところではあるのですが、実を言うと dist.Delta の使い方がいまいちよくわからなかったので、仕方無しに観測方程式を追加した。numpyro.deterministicsample("name", dist.Delta) の糖衣構文だと思っていたのだが、そうではないっぽいか。)

事前分布からのサンプリング

一応、推論を回す前の事前分布からのサンプリングでどんな時系列を見せてくれるのかだけ確認します。

model = Model1(X, len(X))
y = handlers.seed(model.model, rng_seed=0)(None, len(X))
plt.plot(y)

f:id:s0sem0y:20200507233822p:plain

すさまじい $\epsilon$ の分散を感じますね。

推論 結果

あとは実装してあるメソッドを殴るだけ。 30期先まで予測してみます。

model.inference(1000, 1000)
samples =model.mcmc.get_samples()

mu_forecast = handlers.seed(model.forecast, random.PRNGKey(0))(30)
mu_samples = np.hstack([model.mcmc.get_samples()["mu"], mu_forecast])
mu_mean = mu_samples.mean(axis=0)
mu_std = mu_samples.std(axis=0)

f:id:s0sem0y:20200507234751p:plain

予測している領域は、平均的にほぼ水平に動くとしながらも、信用区間はどんどん広がっていっています。