HELLO CYBERNETICS

深層学習、機械学習、強化学習、信号処理、制御工学、量子計算などをテーマに扱っていきます

【Jax NumPyro vs PyTorch Pyro】階層ベイズモデルMCMC対決

 

 

follow us in feedly

はじめに

最も使い慣れているPyTorchに周辺ライブラリが充実してきて、TensorFlow2系を追うのも完全に休止して内心喜んでいたところでございます。しかしそれも束の間、「PyroのMCMCおそすぎる…」問題に直撃してしまいました。もちろん遅いのは前から分かっていましたが、リリース版になりJitも充実してきたところでいつかは…と淡い期待を抱いていたのです。しかし、今も変わらず遅いままなのでNumPyroを触ってみました。

データ

今回は僕の実家の牧場が営んでいるアイスクリームの1ヶ月間の売上です。

これのモデルを書いてみます。 ちなみに興味はないと思いますが weather は 1 のときに雨を表しています。sells 単位は 百円 です。

    temparature weather sells
0  35.248615  0.0    132.44
1  37.528355  0.0    142.51
2  32.518875  0.0    127.70
3  26.970680  0.0    116.02
4  32.901375  1.0    85.37
5  34.352425  0.0    137.56
6  29.312223  1.0    67.04
7  30.941996  1.0    71.51
8  30.234177  0.0    104.16
9  29.689245  1.0    87.56
10 38.934875  0.0    139.69
11 34.785976  0.0    142.83
12 37.064690  0.0    156.58
13 32.293087  0.0    142.69
14 31.418299  0.0    129.36
15 30.507414  0.0    113.24
16 16.546255  0.0    57.07
17 35.479027  0.0    140.57
18 20.724728  0.0    71.65
19 20.487600  0.0    89.17
20 25.657394  0.0    94.38
21 26.641348  0.0    104.81
22 31.730848  0.0    121.97
23 36.050186  0.0    138.66
24 33.208298  0.0    117.27
25 34.133427  0.0    117.86
26 37.877991  0.0    160.74
27 27.431259  1.0    64.43
28 23.955915  0.0    101.83
29 32.517075  1.0    71.76
30 29.826220  0.0    90.90

きっとアイスクリームなので気温が高ければ高いほど売れるんでしょ?ということが想定できます。 ということでPlotしてみます。

f:id:s0sem0y:20200218181704p:plain

おー、キレイにそんな気がしますね。しかし、生データを眺めても天候が悪いと気温が高くてもアイスクリーム売れてません。そもそも牧場の来場者数が少ないのでしょう(来場者のデータもあればよかったんだけどね)。

とにかく、気温が高いほどアイスは売れそう、ということと、おそらくその傾向は大いにあるのだが、天候次第で売れ行きはかなりシフトするということがありそうです。ということでそれをモデルにします。

ちなみに天候によって気温は変わるのかということに関して、変わりそうなんだけどこのデータにはあまり相関がありませんでした(大きさが 0.1 未満だったよ)。

モデル

Pyro

def model_pyro(temp_obs, weather_obs, sells_obs):
    temp_means = sample(
        "temp_means", 
        dist.Normal(loc=torch.tensor(30.),
                    scale=torch.tensor(2.0))
    )
    temp_stds = sample(
        "temp_stds", 
        dist.LogNormal(loc=torch.tensor(0.),
                       scale=torch.tensor(2.))
    )

    sells_std = sample(
        "sells_std", 
        dist.LogNormal(loc=torch.tensor([0., 0.]),
                    scale=torch.tensor([5.0, 5.0]))
    )

    with plate("temps_", 2):
        temps = sample(
            "temps",
            dist.Normal(temp_means, temp_stds)
        )
    
    temp_coeff = sample("temp_coeff",
                        dist.Normal(torch.tensor([0., 0.]), torch.tensor([1., 1.])))
    temp_bias = sample("temp_bias",
                       dist.Normal(torch.tensor([0., 0.]), torch.tensor([10., 10.])))
    weather_prob = sample("weather_prob",
                         dist.Beta(1.0, 1.0))

    with plate("days", size=len(temp_obs)):
        weather = sample(
            "weather", 
            dist.Bernoulli(probs=weather_prob).expand([len(temp_obs)]),
            obs=weather_obs
        ).long()
        temp = sample(
            "temp", 
            dist.Normal(loc=temps[weather], scale=temp_stds),
            obs=temp_obs
        )
        sells = sample(
            "sells",
            dist.Normal(loc=temp_coeff[weather] * temp + temp_bias[weather], 
                        scale=sells_std[weather]),
            obs=sells_obs
        )
    return temp, weather, sells

モデルはまあまあ適当ですが、気温から売上への回帰モデルが天候毎に切り替わる感じです。あとは売上のばらつきも天候毎に違うことにしています。

nuts = infer.mcmc.NUTS(model_pyro, jit_compile=True, ignore_jit_warnings=True,
                       max_tree_depth=10)
mcmc = infer.mcmc.MCMC(nuts, 1000, 200)
mcmc.run(temp, wether.float(), sells)

jitもしっかり入れて標準的なパラメータ設定でNUTSを準備しました。そして200サンプルでwarm up、Adaptive step sizeを利用しこの間にstep sizeを調整しました。その後1000サンプルを得るまでMCMCを回します。

結果的に2分30秒を要しました。おそすぎや…。

NumPyro

NumPyroはバックエンドがJaxになっているだけで、基本的なインターフェースは同じです。ただし、jit前提で動くため後に示すように、モデル内部の変数、例えば plate("name", len(data)) などは最初の評価時に len(data) = 30 であれば、30 でその後も固定されてしまっています。要するに len() という関数がコンパイルされており、評価をその都度してくれるわけではなさそうです。多分…。

なのでこの部分を可変にしたい場合は、model(data, num_data) などとして外から与えるようにしましょう。(本当にそうしなきゃいけないかは要確認)

def model(temp_obs, weather_obs, sells_obs):

    temp_means = sample(
        "temp_means", 
        dist.Normal(loc=jnp.array(30.),
                    scale=jnp.array(2.0)),
    )
    temp_stds = sample(
        "temp_stds", 
        dist.LogNormal(loc=jnp.array(0.),
                       scale=jnp.array(2.)),
    )

    sells_std = sample(
        "sells_std", 
        dist.LogNormal(loc=jnp.array([0., 0.]),
                    scale=jnp.array([5.0, 5.0])),
    )

    with plate("temps_", 2):
        temps = sample(
            "temps",
            dist.Normal(temp_means, temp_stds),
        )
    
    temp_coeff = sample("temp_coeff",
                        dist.Normal(jnp.array([0., 0.]), jnp.array([1., 1.])),
                        )
    temp_bias = sample("temp_bias",
                       dist.Normal(jnp.array([0., 0.]), jnp.array([100., 100.])),
                       )
    weather_prob = sample("weather_prob",
                         dist.Beta(1.0, 1.0),
                         )

    with plate("days", size=len(temp_data)):
        weather = sample(
            "weather", 
            dist.Bernoulli(probs=weather_prob),
            obs=weather_obs,
        ).astype(jnp.int32)

        temp = sample(
            "temp", 
            dist.Normal(loc=temps[weather], scale=temp_stds),
            obs=temp_obs,
        )

        sells = sample(
            "sells",
            dist.Normal(loc=temp_coeff[weather] * temp + temp_bias[weather], 
                           scale=sells_std[weather]),
            obs=sells_obs,
        )
    return temp, weather, sells


mcmc = infer.MCMC(infer.NUTS(model, max_tree_depth=10), 
                  num_warmup=200, 
                  num_samples=1000, 
                  progress_bar=True,)
mcmc.run(random.PRNGKey(2019), temp_data, weather_data, sells_data)

Jax版ではjax.random.PRNGKey が乱数を利用する場合に必要になります。この値を固定している限りは、完全に再現性を担保してくれるようです(もはや乱数じゃなく、何らかの写像になっている)。そんなことはともかく、

肝心な計算時間は11秒。JIT含めても15秒程。

もはやNumPyroが速いのかPyroがおそすぎるのかよく分かりませんが、圧倒的な差を見せつけられたのでPyroはGpyTorchなどと連携するとき以外は封印することにします。

おまけ(推論結果)

とりあえず簡単に、回帰モデルの気温の売上への効きは事後分布coeff を確認すると分かり、

f:id:s0sem0y:20200218184519p:plain

という具合でした。(緑が雨、青が晴れ)

雨の日はどうも、気温が高かろうが低かろうがあんまり効いていない感じがします(大体買う気無いということか?)。晴天時は正の値に分布がいるので、気温が高い程買ってくれる人が多そうですね。

30日間の売上に対する予測分布はこんな感じになっています。モデルさん自信無さすぎかもね。

f:id:s0sem0y:20200218184726p:plain

ちなみに実家が牧場というのは嘘です。