HELLO CYBERNETICS

深層学習、機械学習、強化学習、信号処理、制御工学、量子計算などをテーマに扱っていきます

Jax, PyTorch 直線探索付き勾配法

 

 

follow us in feedly

はじめに

前回は下記の記事で学習率固定で勾配法を実施しました。

www.hellocybernetics.tech

今回はウルフ条件を満たすような学習率を各更新時にバックステップで探索し、満たすものを見つけたら直ちにその学習率の更新するという形式で勾配法を実施します。

この記事ではJaxとPyTorchで収束までのステップ数や収束先等の結果はほぼ一致しましたが、速度が圧倒的にJaxの方が速く、PyTorchの計算グラフが変なふうになってしまっている可能性があります(こんなPyTorch遅いわけがない…!)

どなたか見つけたら教えて下さい…。

モジュールインポート

import jax
import jax.numpy as jnp
from jax import value_and_grad, jit, grad

import torch
import torch.autograd as ag

import matplotlib.pyplot as plt
import numpy as np

最適化問題はこんなもの

def objective(x):
    f = x[0]**4 + 2*x[0]**3 + 3*x[1]**2 + 2*x[0]*x[1] - x[1]
    return f

x1 = np.linspace(-4, 4)
x2 = np.linspace(-4, 4)
xx1, xx2 = np.meshgrid(x1, x2) 
f = objective([xx1, xx2])

x_init = jnp.array([3., 3.])
f_init = objective(x_init)

f:id:s0sem0y:20201104013513p:plain

Jax

勾配関数と線形探索関数を準備

勾配と評価値を得る関数は、value_and_grad に目的関数を渡せば良いだけ。自動微分バンザイです。 繰り返しiterationの中で利用されるpure関数なら何でもjitしておけば良いと思います。(使うのが一回とかだと返ってコンパイルがネックになるかも)

線形探索の方は学習率 a を引数に取る関数として書きます。これを自動微分関数を生成する grad に渡せば導関数が得られます。あとは探索を続ける条件を関数 cond(a) として書いておき、探索を続ける場合の処理も body(a) として書いておけば、学習率の初期値a_init を予め設定し jax.lax.while_loop(cond, body, a) に渡せば bodyTrue を返す限りbody を繰り返し続けます。

jit 前提ではPythonの構文は書かないでjax.lax の制御関数を利用しましょう。コンパイルの時間が全く異なります。

v_and_g = jit(value_and_grad(objective))

@jit
def linear_search(x, dx, tau1=0.5, tau2=0.8):
    a_init = 1.0
    beta = 0.999

    obj_a = lambda a: objective(x + dx*a)
    grad_obj_a = grad(obj_a)

    def armijo_cond(a):
        return obj_a(a) > obj_a(0.0) + tau1*grad_obj_a(0.0)*a

    def wolfe_cond(a):
        cond1 = armijo_cond(a)
        cond2 = grad_obj_a(a) < tau2*grad_obj_a(0.0)
        return cond1 + cond2
    
    body = lambda a: a*beta
    
    a = jax.lax.while_loop(wolfe_cond, body, a_init)
    return a

最適化実行

あとは普通の勾配法と同じです。学習率 a だけ各ステップに自動で決定されます。

x = jnp.array([3.0, 3.0])
tmp_val = 1e5

x_list = [x]
a_list = []
while True:
    val, grad_val = v_and_g(x)
    a = linear_search(x, -grad_val)
    x = x - a*grad_val
    a_list.append(a)
    x_list.append(x)
    if jnp.abs(val - tmp_val).sum() < 1e-3:
        break
    tmp_val = val

ちなみに10ステップで完了し、学習率の遷移はこんな感じです。ちなみに a = 1e-3 とかで決め打ちすると3桁ステップ掛かりました。しかも局所解に捕まる始末でございます。線形探索、単純だけど強力なのね。ただ、今回利用しているウルフ条件は、バックステップ法で条件を満たすものが見つかるとも限らなさそうで、実際やっていることは怪しい。もっと効率は悪いがアルミホ条件で妥協するほうが安全ではありそうです。

plt.plot(a_list)

f:id:s0sem0y:20201104014645p:plain

plt.figure(figsize=(7,7))
plt.contourf(xx1, xx2, f, cmap='Blues')
for i, x_trj in enumerate(x_list):
    plt.scatter(x_trj[0], x_trj[1], c="r", s=80)
    if i == len(x_list) - 1:
        plt.scatter(x_trj[0], x_trj[1], c="g", s=80)
        plt.title(f" iter: {len(x_list)} solution : {val}")

f:id:s0sem0y:20201104014730p:plain

PyTorch

線形探索関数準備

はい、torch.nn.Module 様とtorch.nn 様、及び PyTorch lightning 様に管理され、ほとんどネットワークを繋いでいるだけの使い方しかしていないため、細かい計算グラフの切り方など間違えている可能性があります。ご指摘願います。

Twitterにてご指摘をいただき a = a*betawith torch.no_grad() コンテキスト内に収めたら3倍高速化しました。

def linear_search(x, dx, tau1=0.5, tau2=0.8):
    beta = 0.999
    init_a = 1.

    obj_a = lambda a: objective(x + dx*a)
    zero = torch.zeros([], requires_grad=True)
    a = torch.tensor(init_a)
    obj_0_val = obj_a(zero)
    grad_obj_0_val = ag.grad(obj_0_val, zero)[0]

    def cond(a):
        a.requires_grad_()
        obj_a_val = obj_a(a)
        grad_obj_a_val = ag.grad(obj_a_val, a)[0]
        cond1 = obj_a_val <= obj_0_val + tau1*grad_obj_0_val*a
        cond2 = grad_obj_a_val >= tau2*grad_obj_0_val
        a.detach()
        return cond1 and cond2

    while True:
        if cond(a):
            break
        with torch.no_grad():
            a = a*beta
    return a.item()

最適化実行

x = torch.tensor([3.0, 3.0], requires_grad=True)
tmp_val = 1e5

x_list = [x.detach().numpy()]
a_list = []
while True:
    val = objective(x)
    grad_val = ag.grad(val, x)[0]
    a = linear_search(x.clone().detach(), -grad_val.clone().detach())
    x = x - a*grad_val
    x_list.append(x.detach().numpy())
    a_list.append(a)
    if torch.abs(val - tmp_val).item() < 1e-3:
        break
    tmp_val = val

可視化は同じなので結果だけ貼って省略します。

f:id:s0sem0y:20201104015619p:plain

結果

google colabで

Jax 0.365 sec

PyTorch 21 sec -> 7.8 sec

絶対なんかおかしい。Jax非同期ディスパッチでPythonでの計測がオカシイにしても、実際体感でJaxは1秒以内、PyTorchは20秒は掛かってた…。 PyTorch修正後 8秒未満まで短縮。 Jaxは jax.lax.while_loop を利用せず Pythonでループを回すと同様に8秒程掛かったことから、jax の制御式の実装がエゲツナク速い模様です(確かにiteration1回1回はかなり軽いので、条件判定等が最もボトルネックだったかもしれないことを考えると、学習ループの中に更に探索ループがいる場合、そこが利いてくるのはそのとおりかも)。